#
# Copyright 2014, NICTA
#
# This software may be distributed and modified according to the terms of
# the BSD 2-Clause license. Note that NO WARRANTY is provided.
# See "LICENSE_BSD2.txt" for details.
#
# @TAG(NICTA_BSD)
#

import braces
import re
import sys
import copy
import os

class Call:
	"""Options selecting which output lines parse() produces.

	Callers set the attributes they need (and must set filename) after
	construction. Every flag that def_lines/get_lines read has a default
	here, so a freshly constructed Call does not raise AttributeError."""
	def __init__(self):
		# Optional predicate on a definition dict; definitions for which it
		# returns false are dropped by get_lines ('comments' are kept).
		self.restr = None
		self.decls_only = False
		self.instanceproofs = False
		self.bodies_only = False
		self.archdefs = False
		# When set, def_lines emits definitions, newtypes and all instance
		# material together.
		self.all_bits = False
		# When set, def_lines delegates to get_lambda_body_lines.
		self.body = False

def parse (call):
	"""Parse the file named by call.filename and return the output lines
	that *call* selects, each terminated by a newline."""
	set_global(call)
	selected = get_lines (get_defs (call.filename), call)
	redirected = perform_module_redirects (selected, call)
	return ['%s\n' % line for line in redirected]

def set_global (_call):
	"""Store _call and its filename in the module globals `call` and
	`filename`, which the transforms read for diagnostics."""
	global call, filename
	call = _call
	filename = call.filename

# Memo of transformed definitions, keyed by source filename (see get_defs).
file_defs = dict()

def splitList (list, pred):
	"""Split a list into maximal runs of elements for which pred is false.

	Elements satisfying pred act as separators and are dropped, and empty
	runs are never emitted."""
	groups = []
	current = []
	for item in list:
		if not pred (item):
			current.append (item)
			continue
		if current:
			groups.append (current)
		current = []
	if current:
		groups.append (current)
	return groups


def takeWhile (list, pred):
	"""Return the longest prefix of *list* whose elements all satisfy pred.

	The prefix is a slice, so it has the same sequence type as *list*."""
	count = 0
	for item in list:
		if not pred (item):
			break
		count += 1
	return list[:count]

def emptyList (l):
	"""Predicate: true exactly when *l* equals the empty string."""
	is_empty = (l == '')
	return is_empty

def get_defs (filename):
	"""Return the transformed definitions of *filename*, memoised in file_defs.

	The file is first run through the C preprocessor. Extra flags come
	from the L4CPP environment variable, which must be set (KeyError
	otherwise). A non-zero cpp exit status is reported on stderr; without
	that report a failed preprocessing run would silently produce a
	truncated or empty definition list."""
	if filename in file_defs:
		return file_defs[filename]

	cmdline = os.environ['L4CPP']
	f = os.popen ('cpp -Wno-invalid-pp-token -traditional-cpp %s %s' % (cmdline, filename))
	lines = [line.rstrip() for line in f]
	# close() returns None on success and the exit status otherwise
	status = f.close()
	if status:
		sys.stderr.write('WARN: cpp exited with status %r while reading %s.\n' \
			% (status, filename))
	defs = top_transform (lines)

	file_defs[filename] = defs
	return defs

def top_transform (input):
	"""Top-level transform of the preprocessed literate-Haskell lines.

	Lines starting with '> ' are code. Any trailing '--' comment is cut
	off, and the remaining code lines are grouped into definitions by the
	offside rule (offside_tree/create_defs). Consecutive clauses of one
	definition are merged (group_defs). Every other line is treated as
	prose. Each definition then goes through defs_transform; those that
	yield None are dropped and the rest are reordered by
	ensure_type_ordering."""
	code_lines = []
	# (line number, 'C', text) entries; collected but currently unused
	comments = []
	for n, line in enumerate(input):
		if '\t' in line:
			sys.stderr.write('WARN: tab in line %d, %s.\n' % \
				(n, filename))

		if not line.startswith ('> '):
			# literate prose: non-blank text is wrapped as a comment
			comments.append ((n, 'C', '(*' + line + '*)' if line.strip() else ''))
			continue

		if '--' in line:
			line = line.split('--')[0].strip()

		if not line[2:].strip():
			comments.append ((n, 'C', ''))
		elif line.startswith ('> {-#'):
			# '{-#' lines (GHC pragmas) are only kept as comments
			comments.append ((n, 'C', '(*' + line + '*)'))
		else:
			code_lines.append ((line[2:], n))

	defs = group_defs (create_defs (offside_tree (code_lines)))

	# Attaching the prose comments to definitions is disabled for now;
	# the former merge logic is kept here for reference.
#	defs_plus_comments = [(d['line'], d) for d in defs] + comments
#	defs_plus_comments.sort()
#	defs = []
#	prev_comments = []
#	for term in defs_plus_comments:
#		if term[1] == 'C':
#			prev_comments.append(term[2])
#		else:
#			d = term[1]
#			d['comments'] = prev_comments
#			defs.append(d)
#			prev_comments = []

	# run defs_transform, discarding definitions it rejects (None)
	transformed = [defs_transform (d) for d in defs]
	kept = [d for d in transformed if d is not None]
	return ensure_type_ordering (kept)

def get_lines(defs, call):
	"""Render the definitions this call asks for into output lines.

	If call.restr is set, only definitions it accepts (plus 'comments'
	entries) are rendered. Each non-empty rendering is followed by a
	blank separator line."""
	if call.restr:
		selected = [d for d in defs
				if d['type'] == 'comments' or call.restr(d)]
	else:
		selected = defs

	output = []
	for d in selected:
		rendered = def_lines (d, call)
		if not rendered:
			continue
		output.extend (rendered)
		output.append ('')
	return output

def offside_tree (input):
	"""Group (line, n) pairs into a tree using the offside rule.

	Each line becomes a node (line, n, children). Its children are the
	following lines that are indented strictly more than it. The first
	later line indented no more than it starts the next sibling."""
	def indent_of (text):
		return len(text) - len(text.lstrip())

	nodes = []
	pos = 0
	total = len(input)
	while pos < total:
		head, head_n = input[pos]
		head_indent = indent_of (head)
		end = pos + 1
		while end < total and indent_of (input[end][0]) > head_indent:
			end += 1
		nodes.append ((head, head_n, offside_tree (input[pos + 1:end])))
		pos = end
	return nodes

def discard_n (tree):
	"""Drop the line numbers from an offside tree, recursively turning
	(line, n, children) nodes into (line, children) nodes."""
	return [(line, discard_n (kids)) for (line, _n, kids) in tree]

def flatten_tree (tree):
	"""Flatten a (line, children) tree back into its lines, in pre-order."""
	lines = []
	pending = list(reversed(tree))
	while pending:
		line, children = pending.pop()
		lines.append (line)
		pending.extend (reversed(children))
	return lines

def create_defs (tree):
	"""Build definition dicts from the top-level nodes of an offside tree,
	skipping the nodes for which create_def returns None."""
	result = []
	for node in tree:
		d = create_def (node)
		if d is not None:
			result.append (d)
	return result

def group_defs (defs):
	"""Merge consecutive clauses that belong to a single definition.

	A type signature and the pattern-matching equations after it arrive as
	separate dicts naming the same function (via 'actual_fn' or
	'defined'). Consecutive 'definitions' entries with the same non-empty
	name are folded into the first by extending its 'body'. All other
	entries pass through unchanged."""
	grouped = []
	current = ''
	for d in defs:
		name = d.get('actual_fn', d['defined'])
		if d['type'] != 'definitions':
			name = ''
		if name and name == current:
			grouped[-1]['body'].extend (d['body'])
			continue
		grouped.append (d)
		current = name
	return grouped

def create_def (elt):
	"""Turn one top-level offside-tree node (line, n, children) into a
	definition dict via create_def_2; the children lose their line numbers."""
	line, n, children = elt
	return create_def_2 (line, discard_n (children), n)

def create_def_2 (line, children, n):
	"""Classify a definition by its first line and build its dict.

	Returns None for import/module/class lines. Otherwise returns a dict
	with 'body' = [(line, children)], 'line' = n, plus:
	  - 'instance' lines: type 'instance', defined = third token;
	  - type/newtype/data lines: type 'newtype', defined = second token;
	  - anything else: type 'definitions', defined = first token."""
	tokens = line.split(None, 3)
	first = tokens[0]
	if first in ('import', 'module', 'class'):
		return None
	d = {'body': [(line, children)], 'line': n}
	if first == 'instance':
		d['type'], d['defined'] = 'instance', tokens[2]
	elif first in ('type', 'newtype', 'data'):
		d['type'], d['defined'] = 'newtype', tokens[1]
	else:
		d['type'], d['defined'] = 'definitions', tokens[0]
	return d

def get_primrecs():
	"""Read the names of definitions to translate with primrec_transform.

	The file 'primrecs' in the current directory lists one name per line.
	Surrounding whitespace is stripped and blank lines are ignored.
	Returns a dict mapping each name to True. The file is closed before
	returning."""
	with open('primrecs') as f:
		keys = [line.strip() for line in f]
	return dict([(key, True) for key in keys if key != ''])

# Names read from the 'primrecs' file in the current directory when this
# module is imported. defs_transform sends definitions with these names
# to primrec_transform.
primrecs = get_primrecs()

def defs_transform (d):
	"""Transform the grouped clauses of one definition.

	newtype and instance entries are delegated to their own transforms.
	For ordinary definitions:
	  - a leading type-signature clause is split off into d['sig'];
	  - names listed in primrecs go to primrec_transform;
	  - multiple equations are merged by pattern_match_transform;
	  - the body is rewritten by body_transform.

	Raises AssertionError (after dumping d to stderr) if no equations
	remain once the signature is removed."""
	if d['type'] == 'newtype':
		return newtype_transform (d)
	elif d['type'] == 'instance':
		return instance_transform (d)

	# First tokens of the first clause. The clause is joined with its
	# continuation lines so that a name standing alone on the first line
	# (e.g. 'foo' followed by an indented ':: T' or '= e') neither raises
	# IndexError nor hides a signature.
	lead = reduce_to_single_line (d['body'][0]).split(None, 2)
	if len(lead) > 1 and lead[1] == '::':
		d['sig'] = type_sig_transform (d['body'][0])
		d['body'] = d['body'][1:]

	if d['defined'] in primrecs:
		return primrec_transform (d)

	if len(d['body']) > 1:
		d['body'] = pattern_match_transform (d['body'])

	if len(d['body']) == 0:
		sys.stderr.write ('\n%s\n' % (d,))
		raise AssertionError ('definition of %s has no equations' % d['defined'])

	d['body'] = body_transform(d['body'], d['defined'], d.get('sig', None))
	return d

def def_lines (d, call):
	"""Render one transformed definition as output lines for *call*.

	The call flags are checked in this order:
	  - arch-promoted definitions are suppressed (and printed) unless
	    call.archdefs is set;
	  - call.all_bits emits everything: comments, the definition or
	    newtype, instance proofs and extra instance definitions;
	  - call.instanceproofs emits only the instance material;
	  - call.body delegates to get_lambda_body_lines;
	  - otherwise the definition is rendered as consts (decls_only),
	    defs/definition bodies (bodies_only), or full definition/primrec
	    text.

	Returns a list of lines, or None for definition types that produce
	no output in the selected mode."""
	if 'arch_promotion' in d and not call.archdefs:
		print (d)
		return []

	if call.all_bits:
		L = []
		if 'comments' in d:
			L.extend (flatten_tree (d['comments']))
			L.append ('')
		if d['type'] == 'definitions':
			L.append('definition')
			if 'sig' in d:
				L.extend (flatten_tree ([d['sig']]))
				L.append ('where')
			# the equations are needed whether or not there is a signature
			L.extend (flatten_tree (d['body']))
		elif d['type'] == 'newtype':
			L.extend (flatten_tree (d['body']))
		if 'instance_proofs' in d:
			L.extend (flatten_tree (d['instance_proofs']))
			L.append ('')
		if 'instance_extras' in d:
			L.extend (flatten_tree (d['instance_extras']))
			L.append ('')
		return L

	if call.instanceproofs:
		if not call.bodies_only:
			instance_proofs = flatten_tree (d.get('instance_proofs', []))
		else:
			instance_proofs = []

		if not call.decls_only:
			instance_extras = flatten_tree (d.get('instance_extras', []))
		else:
			instance_extras = []

		newline_needed = len(instance_proofs) > 0 and len(instance_extras) > 0
		return instance_proofs + ([''] if newline_needed else []) + instance_extras

	if call.body:
		return get_lambda_body_lines (d)

	comments = d.get('comments', [])
	if 'sig' in d:
		typesig = flatten_tree ([d['sig']])
	else:
		typesig = []
	body = flatten_tree (d['body'])
	kind = d['type']

	if kind == 'definitions':
		if call.decls_only:
			if typesig:
				return comments + ['consts'] + typesig
			else:
				return []
		elif call.bodies_only:
			if 'sig' in d:
				if 'actual_fn' in d:
					defname = '%s_def' % d['actual_fn']
				else:
					defname = '%s_def' % d['defined']
				if 'primrec' in d:
					print ('warning body-only primrec:')
					print (body[0])
					return comments + ['primrec'] + body
				return comments + ['defs %s:' % defname] + body
			else:
				return comments + ['definition'] + body
		else:
			if 'primrec' in d:
				return comments + ['primrec'] + typesig \
					+ ['where'] + body
			if typesig:
				return comments + ['definition'] + typesig + ['where'] + body
			else:
				return comments + ['definition'] + body
	elif kind == 'comments':
		return comments
	elif kind == 'newtype':
		if not call.bodies_only:
			return body

def type_sig_transform (tree_element):
	"""Translate a Haskell signature clause 'name :: type', which may span
	continuation lines, into 'name :: "<isabelle type>"'.

	Only the first '::' separates the name from the type, so a later
	'::' in the line does not break the unpacking. If the translated type
	contains '[pp' it is treated as a failed translation: the pieces are
	dumped to stderr and AssertionError is raised.

	Returns a (line, []) tree element."""
	line = reduce_to_single_line (tree_element)
	(pre, post) = line.split('::', 1)
	result = type_transform (post)
	if '[pp' in result:
		sys.stderr.write ('%s\n%s\n%s\n%s\n' % (line, pre, post, result))
		raise AssertionError ('bad type translation: %s' % result)
	line = pre + ':: "' + result + '"'

	return (line, [])

# Haskell classes dropped entirely when translating class contexts.
ignore_classes = dict(Error=1)
# Haskell classes translated by hand into the listed Isabelle sorts.
hand_classes = dict(
	Bits=['HS_bit'],
	Num=['minus', 'one', 'zero', 'plus', 'numeral'],
)

def type_transform (string):
	"""Translate a Haskell type expression into Isabelle syntax.

	Class context ('C a => t' or '(C a, D b) => t'):
	  - t is translated, then the first whole occurrence of each
	    constrained type variable is annotated with its sort, e.g.
	    "('a :: HS_bit)" or "('a :: {minus, one})";
	  - classes in ignore_classes are dropped;
	  - hand_classes supplies explicit sort lists;
	  - other class names are converted with type_conv.

	No context:
	  - '()' is replaced by Unit;
	  - the type, wrapped as a braces.str, is split on '->' and rejoined
	    with \\<Rightarrow>, or, if there is no arrow, split on ',' and
	    rejoined with '*';
	  - each piece is translated by type_bit_transform."""
	# deal with type classes by recursion
	bits = string.split ('=>', 1)
	if len(bits) == 2:
		lhs = bits[0].strip()
		if lhs.startswith('(') and lhs.endswith(')'):
			instances = lhs[1:-1].split(',')
		else:
			instances = [lhs]
		var_annotes = {}
		for instance in instances:
			(name, var) = instance.split()
			if name in ignore_classes:
				continue
			if name in hand_classes:
				names = hand_classes[name]
			else:
				names = [type_conv(name)]
			var_annotes.setdefault ("'" + var, []).extend (names)
		transformed = type_transform(bits[1])
		for (var, insts) in var_annotes.items():
			if len (insts) == 1:
				newvar = '(%s :: %s)' % (var, insts[0])
			else:
				newvar = '(%s :: {%s})' % (var, ', '.join(insts))
			# Annotate only the first whole occurrence of var. The lookahead
			# stops 'a from matching the start of another variable like 'ab.
			pattern = re.escape (var) + r"(?![\w'])"
			transformed = re.sub (pattern, lambda m, s = newvar: s,
					transformed, count = 1)
		return transformed

	# get rid of (), insert Unit, which converts to unit
	string = 'Unit'.join(string.split('()'))

	# divide up by -> or by , then divide on space.
	# apply everything locally then work back up
	bstring = braces.str (string, '(', ')')
	bits = bstring.split('->')
	r = ' \\<Rightarrow> '
	if len(bits) == 1:
		bits = bstring.split(',')
		r = ' * '
	result = [type_bit_transform (bit) for bit in bits]
	return r.join(result)

def type_bit_transform (bit):
	"""Translate one piece of a Haskell type (a braces.str from
	type_transform's split) into Isabelle syntax.

	  - '[t]' becomes '<t> list';
	  - 'PPtr x' becomes machine_word;
	  - 'T [x]' becomes '<x> list <T>';
	  - any other application is converted token by token with
	    type_conv, reordered by constructor_reversing, and each token is
	    mapped with type_transform (via the braces object's map)."""
	text = str(bit).strip()
	if text.startswith ('['):
		# handling this properly is hard.
		assert text.endswith (']')
		return '%s list' % type_bit_transform (braces.str(text[1:-1], '(', ')'))

	words = bit.split(None, braces = True)
	if str(words[0]) == 'PPtr':
		assert len(words) == 2
		return 'machine_word'

	if len(words) > 1 and words[1].startswith ('['):
		assert words[-1].endswith (']')
		element = ' '.join(map(str, words[1:]))[1:-1]
		element = type_transform (element)
		return ' '.join([element, 'list', str (type_conv(words[0]))])

	converted = constructor_reversing ([type_conv(w) for w in words])
	mapped = [w.map (type_transform) for w in converted]
	return ' '.join([str (w) for w in mapped])

def reduce_to_single_line (tree_element):
	"""Join a (line, children) tree into one line.

	The result is the head line followed by every descendant, each
	stripped and in pre-order, separated by single spaces."""
	line, children = tree_element
	pieces = [line]
	for child in children:
		pieces.append (reduce_to_single_line (child).strip())
	return ' '.join (pieces)

# Fixed Haskell -> Isabelle type-name translations. type_conv also stores
# every name it converts with its generic CamelCase rule here, so the
# table doubles as a cache.
conv_table = dict(Maybe='option', Bool='bool', Word='machine_word',
		Int='nat', String='unit list')

def type_conv (string):
	"""Convert a Haskell type name to its Isabelle equivalent.

	  - Parenthesised (compound) types are returned unchanged.
	  - In qualified names only the final component is converted.
	  - Lower-case names are type variables and get a leading quote.
	  - '[T]' becomes '<T> list'.
	  - Names in conv_table use the table.
	  - Any other name goes from CamelCase to snake_case and is cached
	    in conv_table.
	The result is wrapped with braces.clone(result, string)."""
	if string.startswith('('):
		# ignore compound types, type_transform will descend into em
		converted = string
	elif '.' in string:
		parts = string.split('.')
		qualifier = reduce (lambda acc, part: acc + '.' + part, parts[:-1])
		converted = qualifier + '.' + type_conv (parts[-1])
	elif string[0].islower():
		converted = "'%s" % string
	elif string[0] == '[' and string[-1] == ']':
		converted = '%s list' % type_conv (string[1:-1])
	elif str(string) in conv_table:
		converted = conv_table[str(string)]
	else:
		converted = ''
		prev_lower = False
		for ch in string:
			if ch.isupper() and prev_lower:
				converted += '_'
			converted += ch.lower()
			prev_lower = ch.islower()
		conv_table[str(string)] = converted

	return braces.clone (converted, string)

def constructor_reversing (tokens):
	"""Reorder the tokens of a Haskell type application into Isabelle's
	postfix type-constructor order.

	  - 'T a' becomes 'a T';
	  - list shapes become '... list' or '(List ...)';
	  - 'array a b' becomes 'a \\<Rightarrow> b';
	  - 'either a b' becomes 'a + b';
	  - 'T a b' and 'T a [b]' become '( a , b ) T', built from braces
	    tokens.
	Any other shape prints an error and exits with status 1."""
	count = len(tokens)
	if count < 2:
		return tokens
	if count == 2:
		return [tokens[1], tokens[0]]
	if tokens[0] == '[' and tokens[2] == ']':
		return [tokens[1], braces.str('list', '(', ')')]
	if count == 4 and tokens[1] == '[' and tokens[3] == ']':
		return [braces.str('(List %s)' % tokens[2], '(', ')'), tokens[0]]
	if tokens[0] == 'array':
		return [tokens[1], braces.str('\\<Rightarrow>', '(', ')'), tokens[2]]
	if tokens[0] == 'either':
		return [tokens[1], braces.str('+', '(', ')'), tokens[2]]

	if count == 5 and tokens[2] == '[' and tokens[4] == ']':
		second = braces.str('(List %s)' % tokens[3], '(', ')')
	elif count == 3:
		# here comes a fudge
		second = tokens[2]
	else:
		print ("Error parsing " + filename)
		print ("Can't deal with %s" % tokens)
		sys.exit(1)

	opening = braces.str('(', '+', '+')
	closing = braces.str(')', '+', '+')
	comma = braces.str(',', '+', '+')
	return [opening, tokens[1], comma, second, closing, tokens[0]]

def newtype_transform (d):
	"""Translate a Haskell type/newtype/data declaration.

	A trailing 'deriving (...)' child is stripped and its class names are
	stored in d['deriving']. The declared name is converted with type_conv
	into d['typename'], and d['typedeps'] is reset.
	  - A declaration without '=' becomes an Isabelle typedecl.
	  - 'type' declarations go to typename_transform.
	  - Declarations without record braces go to
	    simple_newtype_transform.
	  - Record declarations go to named_newtype_transform."""
	if len(d['body']) != 1:
		print ('--- newtype long body ---')
		print (d)
	[(line, children)] = d['body']

	if children and children[-1][0].lstrip().startswith('deriving'):
		deriving_text = reduce_to_single_line (children[-1])
		children = children[:-1]
		separators = re.compile (r"[,\s\(\)]+")
		d['deriving'] = [word for word in separators.split(deriving_text)
				if word and word != 'deriving']

	parts = reduce_to_single_line ((line, children)).split(None, 1)
	keyword = parts[0]
	sides = parts[1].split('=', 1)
	header = type_conv (sides[0].strip())
	d['typename'] = header
	d['typedeps'] = {}
	if len(sides) == 1:
		# line of form 'data Blah' introduces unknown type?
		d['body'] = [('typedecl %s' % header, [])]
		all_type_arities[header] = [] # HACK
		return d

	rhs = sides[1]
	if keyword == 'type':
		return typename_transform (rhs, header, d)
	if '{' not in rhs:
		return simple_newtype_transform (rhs, header, d)
	return named_newtype_transform (rhs, header, d)

def typename_transform (line, header, d):
	"""Translate 'type Header = OldType' into an Isabelle type_synonym.

	Only a single-token right-hand side is supported. Anything else is
	reported on stderr and None is returned, which drops the declaration.
	A 'Data.Word.' qualifier on Word types is removed before conversion.
	The tokens of the converted type are recorded in d['typedeps'].
	Returns d with d['body'] replaced."""
	try:
		[oldtype] = line.split()
	except ValueError:
		# zero or several tokens on the right-hand side
		sys.stderr.write ('Warning: type assignment %s\n' % d['body'])
		return None
	if oldtype.startswith('Data.Word.Word'):
		oldtype = oldtype[len('Data.Word.'):]
	oldtype = type_conv(oldtype)
	for bit in oldtype.split():
		d['typedeps'][bit] = 1
	lines = ['type_synonym %s = "%s"' % (header, oldtype)]
	d['body'] = [(line, []) for line in lines]
	return d

# Types that simple_newtype_transform always emits as a datatype, even
# when they consist of a single one-argument constructor.
dontwrap = dict(asidpool=1)

def simple_newtype_transform (line, header, d):
	"""Translate a data/newtype with positional (unnamed) fields into an
	Isabelle datatype.

	*line* is the right-hand side 'C1 t1 t2 | C2 ...'. Each field type is
	converted (parenthesised ones via type_transform, others via
	type_conv), quoted if it contains a space, and recorded in
	d['typedeps']. If the constructors reduce to a single name of arity
	one and header is not in dontwrap, type_wrapper_type is used instead.
	Otherwise instance proofs are attached and d is returned."""
	cons_lines = []
	arities = []
	for index, alternative in enumerate (line.split ('|')):
		words = braces.str (alternative, '(', ')').split()
		if len(words) == 2:
			last_lhs = words[0]

		if index == 0:
			prefix = '    '
		else:
			prefix = '  | '
		text = '%s%s' % (prefix, words[0])
		for field in words[1:]:
			if field.startswith ('('):
				typename = type_transform (str(field[1:-1]))
			else:
				typename = type_conv (str(field))
			if len(words) == 2:
				last_rhs = typename
			if ' ' in typename:
				typename = '"%s"' % typename
			text = text + ' ' + typename
			d['typedeps'][typename] = 1
		cons_lines.append (text)

		arities.append ((str(words[0]), len(words[1:])))

	if (dict (arities)).values() == [1] and header not in dontwrap:
		return type_wrapper_type (header, last_lhs, last_rhs, d)

	d['body'] = [('datatype %s =' % header,
			[(cons_line, []) for cons_line in cons_lines])]

	set_instance_proofs (header, arities, d)

	return d

# Constructor name -> its get_type_map field list; named_newtype_transform
# adds the constructors of every record type it translates.
all_constructor_args = dict()

def named_newtype_transform (line, header, d):
	"""Translate a data/newtype declaration whose constructors use record
	(named-field) syntax.

	*line* is the right-hand side of the declaration, split on '|'. Each
	alternative is parsed by get_type_map (defined elsewhere). As used
	here, it returns (constructor, [(field, type), ...]). The results are
	also merged into all_constructor_args.

	The generated Isabelle text is:
	  - a datatype with one line per constructor;
	  - a primrec extractor and a primrec <field>_update for every field;
	  - a record-style abbreviation for each constructor that has fields;
	  - is<Cons> checks when there are several constructors;
	  - extractor/update [simp] lemmas when there is just one.
	Tokens of the field types are recorded in d['typedeps']. A type with
	exactly one one-field constructor is emitted through type_wrapper_type
	instead. Returns d.

	NOTE(review): constructor order comes from dict iteration, which in
	Python 2 is not guaranteed to match source order."""
	bits = line.split('|')

	constructors = dict ([get_type_map (bit) for bit in bits])
	all_constructor_args.update(constructors)
	
	# datatype alternatives, one line per constructor
	lines = []
	for i, (name, map) in enumerate (constructors.iteritems()):
		if i == 0:
			l = '    %s' % name
		else:
			l = '  | %s' % name
		for name, type in map:
			# single-token types are written unquoted
			if len(type.split()) == 1 and '(' not in type:
				l = l + ' ' + type
			else:
				l = l + ' "' + type + '"'
			for bit in type.split():
				d['typedeps'][bit] = 1
		lines.append(l)
	
	# names: field -> {constructor: position of the field in it}
	# types: field -> field type (last one seen wins)
	names = {}
	types = {}
	for cons, map in constructors.iteritems():
		for i, (name, type) in enumerate(map):
			names.setdefault (name, {})
			names[name][cons] = i
			types[name] = type
	
	for name, map in names.iteritems():
		lines.append ('')
		lines.extend (named_extractor_definitions (name, map, types[name],
					header, constructors))
	
	for name, map in names.iteritems():
		lines.append ('')
		lines.extend (named_update_definitions (name, map, types[name],
					header, constructors))

	for name, map in constructors.iteritems():
		if map == []:
			continue
		lines.append ('')
		lines.extend (named_constructor_translation (name, map, header))
	
	if len(constructors) > 1:
		for name, map in constructors.iteritems():
			lines.append ('')
			check = named_constructor_check (name, map, header)
			lines.extend (check)

	if len(constructors) == 1:
		for ex_name, _ in names.iteritems():
			for up_name, _ in names.iteritems():
				lines.append ('')
				lines.extend (named_extractor_update_lemma (ex_name, up_name))

	arities = [(name, len(map))
			for (name, map) in constructors.iteritems()]

	# exactly one constructor with one field: emit a wrapper type instead;
	# the lines built above are not used on this path
	if (dict (arities)).values() == [1]:
		[(cons, map)] = constructors.items()
		[(name, type)] = map
		return type_wrapper_type (header, cons, type, d,
			decons=(name, type))

	set_instance_proofs (header, arities, d)
	
	d['body'] = [('datatype %s =' % header,
			[(line, []) for line in lines])]
	return d

def named_extractor_update_lemma (ex_name, up_name):
	"""Return the lines of a [simp] lemma relating the extractor ex_name to
	the update function <up_name>_update.

	When the two fields coincide, the extractor sees f applied to the old
	value; otherwise the update leaves it unchanged."""
	if up_name == ex_name:
		statement = '  "%s (%s_update f v) = f (%s v)"'
	else:
		statement = '  "%s (%s_update f v) = %s v"'
	return [
		'lemma %s_%s_update [simp]:' % (ex_name, up_name),
		statement % (ex_name, up_name, ex_name),
		'  by (cases v) simp',
	]

def named_extractor_definitions (name, map, type, header, constructors):
	"""Return the lines of a primrec extractor for field *name*.

	*map* maps each constructor having the field to the field's position.
	*constructors* maps each constructor to its field list, which fixes
	the number of pattern variables. The extractor has type
	header \\<Rightarrow> type."""
	lines = [
		'primrec',
		'  %s :: "%s \\<Rightarrow> %s"' % (name, header, type),
		'where',
	]
	for position, (cons, i) in enumerate (map.items()):
		if position == 0:
			lead = '  "'
		else:
			lead = '| "'
		args = ''.join([' v%d' % k for k in range (len(constructors[cons]))])
		lines.append ('%s%s (%s%s) = v%d"' % (lead, name, cons, args, i))
	return lines

def named_update_definitions (name, map, type, header, constructors):
	"""Return the lines of a primrec <name>_update for field *name*.

	The update has type (type \\<Rightarrow> type) \\<Rightarrow> header
	\\<Rightarrow> header. A multi-word field type is parenthesised. For
	each constructor in *map* (constructor -> field position), the
	equation rebuilds the value with f applied to that position."""
	arrow = '\\<Rightarrow>'
	if len(type.split()) > 1:
		field_type = '(%s)' % type
	else:
		field_type = type
	lines = [
		'primrec',
		'  %s_update :: "(%s %s %s) %s %s %s %s"' \
			% (name, field_type, arrow, field_type, arrow, header, arrow, header),
		'where',
	]
	for position, (cons, i) in enumerate (map.items()):
		if position == 0:
			lead = '  "'
		else:
			lead = '| "'
		arity = len(constructors[cons])
		params = ''.join([' v%d' % k for k in range (arity)])
		rebuilt = ''.join([(' (f v%d)' if k == i else ' v%d') % k
				for k in range (arity)])
		lines.append ('%s%s_update f (%s%s) = %s%s"' \
			% (lead, name, cons, params, cons, rebuilt))
	return lines

def named_constructor_translation (name, map, header):
	"""Return lines for an input-only abbreviation <name>_trans.

	The abbreviation gives constructor *name* a record-like syntax,
	'<name>_ \\<lparr> f0= v0, ... \\<rparr>'. *map* is the non-empty
	list of (field, type) pairs."""
	first_field = map[0][0]
	arg_types = ''.join(['(' + ftype + ') \\<Rightarrow> ' for _, ftype in map])
	holes = ', '.join(['%s= _' % field for field, _ in map])
	binds = ', '.join(['%s= v%d' % (field, i) for i, (field, _) in enumerate (map)])
	args = ''.join([' v%d' % i for i in range (len(map))])
	assert holes.startswith (first_field)
	return [
		'abbreviation (input)',
		'  %s_trans :: "%s%s" ("%s\'_ \\<lparr> %s \\<rparr>")' \
			% (name, arg_types, header, name, holes),
		'where',
		'  "%s_ \\<lparr> %s \\<rparr> == %s%s"' % (name, binds, name, args),
	]

def named_constructor_check (name, map, header):
	"""Return lines defining is<name> :: header \\<Rightarrow> bool, which
	is True exactly for values built by constructor *name*. The
	constructor takes one pattern variable per entry of *map*."""
	args = ''.join(['v%d ' % i for i, _ in enumerate (map)])
	return [
		'definition',
		'  is%s :: "%s \\<Rightarrow> bool"' % (name, header),
		'where',
		' "is%s v \\<equiv> case v of' % name,
		'    %s %s\\<Rightarrow> True' % (name, args),
		'  | _ \\<Rightarrow> False"',
	]

def type_wrapper_type (header, cons, rhs, d, decons = None):
	"""Emit a type with one single-field constructor as a type synonym.

	The constructor *cons* becomes a [simp] identity definition. With
	decons=(field, field_type), the field accessor also becomes an
	identity definition, together with an update function and a
	record-style abbreviation from named_constructor_translation. If
	d['typedeps'] contains a '\\<Rightarrow>' token, the declaration is
	omitted and replaced by a comment. Sets d['body'] and returns d."""
	if '\\<Rightarrow>' in d['typedeps']:
		d['body'] = [
			('(* type declaration of %s omitted *)' % header, [])
			]
		return d
	lines = [
		'type_synonym %s = "%s"' % (header, rhs),
		'',
		'definition',
		'  %s :: "%s \\<Rightarrow> %s"' % (cons, header, header),
		'where %s_def[simp]:' % cons,
		' "%s \\<equiv> id"' % cons,
	]
	if decons:
		(decons, decons_type) = decons
		lines.extend ([
		'',
		'definition',
		'  %s :: "%s \\<Rightarrow> %s"' % (decons, header, header),
		'where',
		'  %s_def[simp]:' % decons,
		' "%s \\<equiv> id"' % decons,
		'',
		# keep the comma: adjacent string literals would otherwise be
		# concatenated into a single output line
		'definition',
		'  %s_update :: "(%s \\<Rightarrow> %s) \\<Rightarrow> %s \\<Rightarrow> %s"' \
			% (decons, header, header, header, header),
		'where',
		'  %s_update_def[simp]:' % decons,
		' "%s_update f y \\<equiv> f y"' % decons,
		''
		])
		lines.extend (
			named_constructor_translation
				(cons, [(decons, decons_type)], header)
		)

	d['body'] = [(line, []) for line in lines]
	return d

def instance_transform (d):
	"""Translate a Haskell 'instance Class Type [where ...]' declaration.

	Show instances and instances on the unit type are discarded (None is
	returned). The methods in the where-block are parsed and transformed
	like top-level definitions and stored by name in d['instance_defs'].
	Class is recorded in d['deriving'], and set_instance_proofs generates
	the instance material from the type's registered constructor arities.
	Exits with status 1 if the type is not yet in all_type_arities."""
	[(line, children)] = d['body']
	words = line.split(None, 3)
	assert words[0] == 'instance'
	classname = words[1]
	typename = type_conv(words[2])
	if classname == 'Show':
		print ("Warning: discarding class instance '%s :: Show'" % typename)
		return None
	if typename == '()':
		print ("Warning: discarding class instance 'unit :: %s'" % classname)
		return None

	if len(words) > 3:
		# 'instance C T where' with the methods as direct children
		assert words[3:] == ['where']
		methods = children
	elif not children:
		methods = []
	else:
		# 'where' on its own continuation line, methods below it
		[(where_line, methods)] = children
		assert where_line.strip() == 'where'

	parsed = [create_def_2 (l, c, 0) for (l, c) in methods]
	grouped = group_defs ([p for p in parsed if p is not None])
	transformed = [defs_transform (m) for m in grouped]
	d['instance_defs'] = dict([(m['defined'], m)
			for m in transformed if m is not None])
	d['deriving'] = [classname]

	if typename not in all_type_arities:
		sys.stderr.write('FAIL: attempting %s\n' % d['defined'])
		sys.stderr.write('(typename %r)\n' % typename)
		sys.stderr.write('when reading %s\n' % filename)
		sys.stderr.write('but class not defined yet\n')
		sys.stderr.write('perhaps parse in different order?\n')
		sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n')
		sys.exit(1)
	set_instance_proofs (typename, all_type_arities[typename], d)

	return d

# Type name -> list of (constructor, arity) pairs. Filled in by
# set_instance_proofs (and by newtype_transform for opaque types) and read
# by instance_transform to check that a type was declared first.
all_type_arities = dict()

def set_instance_proofs (header, constructor_arities, d):
	"""Register a type's constructor arities and attach its instance material.

	constructor_arities is a list of (constructor, arity) pairs and is
	recorded in all_type_arities[header]. For every class named in
	d['deriving'], the generator instance_proof_table[class] runs once
	(duplicates removed). Generators run in increasing .order, with ties
	broken by function name so the output does not depend on hash order.
	The resulting proof lines and extra definitions are stored as
	single-node trees in d['instance_proofs'] and d['instance_extras']."""
	all_type_arities[header] = constructor_arities
	pfs = []
	exs = []
	canonical = list(enumerate(constructor_arities))

	generators = set([instance_proof_table[classname]
				for classname in d.get('deriving', [])])
	ordered = sorted (generators,
			key = lambda f: (f.order, getattr(f, '__name__', '')))
	for f in ordered:
		(npfs, nexs) = f(header, canonical, d)
		pfs.extend (npfs)
		exs.extend (nexs)

	# finite instances are currently disabled ('and False')
	if d['type'] == 'newtype' and len(canonical) == 1 and False:
		[(cons, n)] = constructor_arities
		if n == 1:
			# finite_instance_proofs returns (proof lines, extra defs)
			(npfs, nexs) = finite_instance_proofs (header, cons)
			pfs.extend (npfs)
			exs.extend (nexs)

	if 0: # serialising going away
		if d['type'] == 'newtype':
			(npfs, nexs) = serialisable_instance_proofs (header, canonical, d)
			pfs.extend (npfs)
			exs.extend (nexs)

	if pfs:
		lead = '(* %s instance proofs *)' % header
		d['instance_proofs'] = [(lead, [(line, []) for line in pfs])]
	if exs:
		lead = '(* %s extra instance defs *)' % header
		d['instance_extras'] = [(lead, [(line, []) for line in exs])]

def finite_instance_proofs (header, cons):
	"""Return (proof lines, []) for 'instance header :: finite'. The proof
	applies finite_surj_type with the constructor *cons*."""
	proof = [
		'',
		'instance %s :: finite' % header,
		'  apply (intro_classes)',
		'  apply (rule_tac f="%s" in finite_surj_type)' % cons,
		'  apply (safe, case_tac x, simp_all)',
		'  done',
	]
	return (proof, [])

def serialisable_instance_proofs (header, canonical, d):
	lines = []
	lines.append ('')
	lines.append ('instance %s :: to_from_byte ..' % header)
	lines.append ('')
	lines.append ('defs (overloaded)')
	lines.append ('  to_byte_%s: "to_byte x \<equiv> case x of' % header)
	# a canonical ordering of the constructors.
	for i, (cons, n) in canonical:
		if i == 0:
			l = '    %s' % cons
		else:
			l = '  | %s' % cons
		for j in range (n):
			l = l + ' v%d' % j
		if len(canonical) > 1:
			l = l + ' \<Rightarrow> [%d]' % i
		else:
			l = l + ' \<Rightarrow> []'
		for j in range (n):
			l = l + ' @ (to_byte v%d)' % j

		lines.append (l)
	lines[-1] = lines[-1] + '"'
	loaders = []
	for i, (cons, n) in canonical:
		if n == 0:
			loaders.append (['   Some (%s, bs)' % cons])
			continue
		l1 = '   (loader'
		for _ in range (n - 1):
			l1 = l1 + ' -w- loader'
		l1 = l1 + ') bs'
		l2 = '    -\<longrightarrow> (\<lambda> (v0'
		for j in range (1, n):
			l2 = l2 + ', v%d' % j
		l2 = l2 + '). %s v0' % cons
		for j in range (1, n):
			l2 = l2 + ' v%d' % j
		l2 = l2 + ')'
		loaders.append ([l1, l2])
	if len(loaders) == 1:
		lines.append ('  loader_%s: "loader bs \<equiv>' % header)
		lines.extend (loaders[0])
		lines[-1] = lines[-1] + '"'
	else:
		lines.append ('  loader_%s: "loader bs_ \<equiv>' % header)
		lines.append ('  case bs_ of [] \<Rightarrow> None')
		lines.append ('| b#bs \<Rightarrow>')
		for i, _ in canonical:
			if i == 0:
				lines.append ('  if b = 0 then')
			else:
				lines.append ('  else if b = %d then' % i)
			lines.extend (loaders[i])
		lines.append ('  else None"')
	lines.append ('')
	lines.append ('lemmas serialisable_%s_defs =' % header) 
	lines.append ('  to_byte_%s loader_%s' % (header, header))
	lines.append ('')
	lines.append ('instance %s :: serialisable' % header)
	lines.append ('  apply intro_classes')
	lines.append ('  apply (case_tac w)')
	lines.append ('  apply (simp_all add: serialisable_%s_defs loader_simps)' \
			% header)
	lines.append ('  done')

	return (lines, [])

# leave type tags 0..11 for explicit use outside of this script
next_type_tag = 12
def storable_instance_proofs (header, canonical, d):
	"""Return (proofs, extradefs) that make *header* a dynamic and storable
	type.

	Each call allocates the next typetag number (13 on the first call).
	Methods objBits, makeObject, loadObject and updateObject found in
	d['instance_defs'] become definitions specialised to header. When
	loadObject or updateObject is absent, the *_default implementation is
	used."""
	proofs = []
	extradefs = []

	global next_type_tag
	next_type_tag = next_type_tag + 1
	proofs.extend ([
		'',
		'defs (overloaded)',
		'  typetag_%s[simp]:' % header,
		' "typetag (x :: %s) \\<equiv> %d"' % (header, next_type_tag),
		'',
		'instance %s :: dynamic' % header,
		'  by (intro_classes, simp)'
	])

	proofs.append ('')
	proofs.append ('instance %s :: storable ..' % header)

	defs = d.get('instance_defs', {})
	# The method bodies below are split at the first occurrence of the
	# method name only, so a right-hand side that mentions the name again
	# is kept intact rather than truncated.
	extradefs.append ('')
	if 'objBits' in defs:
		extradefs.append ('definition')
		body = flatten_tree (defs['objBits']['body'])
		bits = body[0].split('objBits', 1)
		assert bits[0].strip() == '"'
		if bits[1].strip().startswith('_'):
			# name the wildcard argument so it can carry a type annotation
			bits[1] = 'x ' + bits[1].strip()[1:]
		bits = bits[1].split(None, 1)
		body[0] = '  objBits_%s: "objBits (%s :: %s) %s' \
				% (header, bits[0], header, bits[1])
		extradefs.extend (body)

	extradefs.append ('')
	if 'makeObject' in defs:
		extradefs.append ('definition')
		body = flatten_tree (defs['makeObject']['body'])
		bits = body[0].split('makeObject', 1)
		assert bits[0].strip() == '"'
		body[0] = '  makeObject_%s: "(makeObject :: %s) %s' \
				% (header, header, bits[1])
		extradefs.extend (body)

	extradefs.extend ([
		'',
		'definition',
	])
	if 'loadObject' in defs:
		extradefs.append ('  loadObject_%s:' % header)
		extradefs.extend (flatten_tree (defs['loadObject']['body']))
	else:
		extradefs.extend ([
			'  loadObject_%s[simp]:' % header,
			' "(loadObject p q n obj) :: %s \\<equiv>' % header,
			'    loadObject_default p q n obj"',
		])

	extradefs.extend ([
		'',
		'definition',
	])
	if 'updateObject' in defs:
		extradefs.append ('  updateObject_%s:' % header)
		body = flatten_tree (defs['updateObject']['body'])
		bits = body[0].split ('updateObject', 1)
		assert bits[0].strip() == '"'
		bits = bits[1].split(None, 1)
		body[0] = ' "updateObject (%s :: %s) %s' \
				% (bits[0], header, bits[1])
		extradefs.extend (body)
	else:
		extradefs.extend ([
			'  updateObject_%s[simp]:' % header,
			' "updateObject (val :: %s) \\<equiv>' % header,
			'    updateObject_default val"',
		])


	return (proofs, extradefs)
storable_instance_proofs.order = 1

def pspace_storable_instance_proofs (header, canonical, d):
	"""Return (proofs, extradefs) that make *header* a pre_storable type.

	The instance proof uses projectKO_opts_defs and splits on
	kernel_object and arch_kernel_object. Methods objBits, makeObject,
	loadObject and updateObject found in d['instance_defs'] become
	definitions specialised to header. When loadObject or updateObject is
	absent, the *_default implementation is used."""
	proofs = []
	extradefs = []

	proofs.append ('')
	proofs.append ('instance %s :: pre_storable' % header)
	proofs.append ('  by (intro_classes,')
	proofs.append ('      auto simp: projectKO_opts_defs split: kernel_object.splits arch_kernel_object.splits)')

	defs = d.get('instance_defs', {})
	# The method bodies below are split at the first occurrence of the
	# method name only, so a right-hand side that mentions the name again
	# is kept intact rather than truncated.
	extradefs.append ('')
	if 'objBits' in defs:
		extradefs.append ('definition')
		body = flatten_tree (defs['objBits']['body'])
		bits = body[0].split('objBits', 1)
		assert bits[0].strip() == '"'
		if bits[1].strip().startswith('_'):
			# name the wildcard argument so it can carry a type annotation
			bits[1] = 'x ' + bits[1].strip()[1:]
		bits = bits[1].split(None, 1)
		body[0] = '  objBits_%s: "objBits (%s :: %s) %s' \
				% (header, bits[0], header, bits[1])
		extradefs.extend (body)

	extradefs.append ('')
	if 'makeObject' in defs:
		extradefs.append ('definition')
		body = flatten_tree (defs['makeObject']['body'])
		bits = body[0].split('makeObject', 1)
		assert bits[0].strip() == '"'
		body[0] = '  makeObject_%s: "(makeObject :: %s) %s' \
				% (header, header, bits[1])
		extradefs.extend (body)

	extradefs.extend ([
		'',
		'definition',
	])
	if 'loadObject' in defs:
		extradefs.append ('  loadObject_%s:' % header)
		extradefs.extend (flatten_tree (defs['loadObject']['body']))
	else:
		extradefs.extend ([
			'  loadObject_%s[simp]:' % header,
			' "(loadObject p q n obj) :: %s kernel \\<equiv>' % header,
			'    loadObject_default p q n obj"',
		])

	extradefs.extend ([
		'',
		'definition',
	])
	if 'updateObject' in defs:
		extradefs.append ('  updateObject_%s:' % header)
		body = flatten_tree (defs['updateObject']['body'])
		bits = body[0].split ('updateObject', 1)
		assert bits[0].strip() == '"'
		bits = bits[1].split(None, 1)
		body[0] = ' "updateObject (%s :: %s) %s' \
				% (bits[0], header, bits[1])
		extradefs.extend (body)
	else:
		extradefs.extend ([
			'  updateObject_%s[simp]:' % header,
			' "updateObject (val :: %s) \\<equiv>' % header,
			'    updateObject_default val"',
		])


	return (proofs, extradefs)
pspace_storable_instance_proofs.order = 1

def num_instance_proofs (header, canonical, d):
	"""Generate Isabelle arithmetic class instances for a wrapper type.

	canonical must hold exactly one constructor, taking exactly one
	argument; plus, minus, zero, one and times are each lifted through
	that constructor with bij_instance.  Returns (proof_lines, [])."""
	assert len(canonical) == 1
	[(_, (cons, n))] = canonical
	assert n == 1
	arith_classes = [
		('plus', [('plus', '%s + %s', True)]),
		('minus', [('minus', '%s - %s', True)]),
		('zero', [('zero', '0', True)]),
		('one', [('one', '1', True)]),
		('times', [('times', '%s * %s', True)]),
	]
	generated = []
	for classname, fns in arith_classes:
		generated.extend (bij_instance (classname, header, cons, fns))
	return (generated, [])
num_instance_proofs.order = 2

def enum_instance_proofs (header, canonical, d):
	"""Generate the Isabelle enum, enum_alt and enumeration_both
	instances for `header`.

	With a single constructor (which must take one argument) the
	enumeration is `map Cons enum`.  Otherwise it lists every nullary
	constructor, followed by `map Cons enum` for each unary one;
	constructors with more arguments are rejected by an assertion.
	The whole text is wrapped in (*<*) ... (*>*).
	Returns (proof_lines, [])."""
	single = len(canonical) == 1
	out = ['(*<*)']
	if single:
		[(_, (cons, n))] = canonical
		assert n == 1
		out += [
			'instantiation %s :: enum begin' % header,
			'definition',
			'  enum_%s: "enum_class.enum \<equiv> map %s enum"' \
				% (header, cons),
		]
	else:
		nullary = [c for _, (c, arity) in canonical if arity == 0]
		unary = [c for _, (c, arity) in canonical if arity == 1]
		assert [c for _, (c, arity) in canonical if arity > 1] == []
		out += [
			'instantiation %s :: enum begin' % header,
			'definition',
			'  enum_%s: "enum_class.enum \<equiv> ' % header,
			'    [ ',
		]
		out += ['      %s,' % c for c in nullary[:-1]]
		out += ['      %s' % c for c in nullary[-1:]]
		out.append ('    ]')
		out += ['    @ (map %s enum)' % c for c in unary]
		# close the quoted definition on whichever line came last
		out[-1] += '"'
	out += [
		'',
		'definition',
		'  "enum_class.enum_all (P :: %s \<Rightarrow> bool) \<longleftrightarrow> Ball UNIV P"' \
			% header,
		'',
		'definition',
		'  "enum_class.enum_ex (P :: %s \<Rightarrow> bool) \<longleftrightarrow> Bex UNIV P"' \
			% header,
		'',
		'  instance',
		'  apply intro_classes',
		'   apply (safe, simp)',
		'   apply (case_tac x)',
	]
	simp_line = '  apply (simp_all add: enum_%s enum_all_%s_def enum_ex_%s_def' \
			% (header, header, header)
	if single:
		out += [simp_line, '    distinct_map_enum)']
	else:
		out += [simp_line + ')', '  apply fast']
	out += [
		'  done',
		'end',
		'',
		'instantiation %s :: enum_alt' % header,
		'begin',
		'definition',
		'  enum_alt_%s: "enum_alt \<equiv> ' % header,
		'    alt_from_ord (enum :: %s list)"' % header,
		'instance ..',
		'end',
		'',
		'instantiation %s :: enumeration_both' % header,
		'begin',
		'instance by (intro_classes, simp add: enum_alt_%s)' % header,
		'end',
		'',
		'(*>*)',
	]
	return (out, [])
enum_instance_proofs.order = 3

def bits_instance_proofs (header, canonical, d):
	"""Generate the Isabelle 'bit' class instance for a one-constructor,
	one-argument wrapper type.  shiftL and shiftR results are re-wrapped
	in the constructor; bitSize is returned unwrapped.
	Returns (proof_lines, [])."""
	assert len(canonical) == 1
	[(_, (cons, n))] = canonical
	assert n == 1
	bit_ops = [
		('shiftL', 'shiftL %s n', True),
		('shiftR', 'shiftR %s n', True),
		('bitSize', 'bitSize %s', False),
	]
	proof_lines = bij_instance ('bit', header, cons, bit_ops)
	return (proof_lines, [])
bits_instance_proofs.order = 5

def no_proofs (header, canonical, d):
	"""Instance generator for classes that need no Isabelle text:
	returns an empty (proof_lines, extra_defs) pair."""
	proof_lines, extra_defs = [], []
	return (proof_lines, extra_defs)
no_proofs.order = 6

# Maps a Haskell type-class name to the function that generates the
# Isabelle instance text for it.  The generators defined in this part
# of the file take (header, canonical, d) and return a
# (proof_lines, extra_def_lines) pair, and carry an `order` attribute
# (presumably used by the caller to sort the generated output --
# confirm at the call site).  Classes mapped to no_proofs produce no
# text.
instance_proof_table = {
	'Eq': no_proofs,
	'Bounded': enum_instance_proofs,
	'Enum': enum_instance_proofs,
	'Ix': no_proofs,
	'Ord': no_proofs,
	'Show': no_proofs,
	'Bits': bits_instance_proofs,
	'Real': no_proofs,
	'Num': num_instance_proofs,
	'Integral': no_proofs,
	'Storable': storable_instance_proofs,
	'PSpaceStorable': pspace_storable_instance_proofs,
	'Typeable': no_proofs,
	'Error': no_proofs,
	}

def bij_instance (classname, typename, constructor, fns):
	"""Build an Isabelle instance of `classname` for `typename`, a
	single-constructor wrapper, by lifting operations through the
	constructor.

	fns is a list of (fname, fstr, cast_return) triples.  fstr is a
	format string whose '%s' holes (at most four) are the operands.
	Each operand is unwrapped with a case expression; when cast_return
	is true the result is wrapped back in `constructor`.  Returns the
	generated lines, starting with a blank line."""
	out = ['', 'instance %s :: %s ..' % (typename, classname),
		'defs (overloaded)']
	for fname, fstr, cast_return in fns:
		args = ('x', 'y', 'z', 'w')[:fstr.count('%s')]
		primed = tuple (arg + "'" for arg in args)
		lhs = fstr % args
		rhs = fstr % primed
		out.append ('  %s_%s: "%s \<equiv>' % (fname, typename, lhs))
		out.extend (["    case %s of %s %s' \<Rightarrow>" \
				% (arg, constructor, arg) for arg in args])
		if cast_return:
			rhs = '%s (%s)' % (constructor, rhs)
		out.append ('      %s"' % rhs)
	return out

def get_type_map (string):
	"""Parse Haskell record syntax "Name { f :: T, g, h :: U }".

	Returns (header, fields): header is the first word, and fields is
	a list of (name, type) pairs in declaration order, with each type
	converted by type_transform.  Names that share one signature
	("g, h :: U") all receive that type.  A bare "Name" gives
	(Name, []).  Prints an error and exits if the braces are missing."""
	head_and_rest = string.split (None, 1)
	header = head_and_rest[0].strip()
	if len(head_and_rest) == 1:
		return (header, [])
	record = head_and_rest[1].strip()
	if not record.startswith('{') or not record.endswith('}'):
		print ('Error in ' + filename)
		print ('Expected "A { blah :: blah etc }"')
		print ('However { and } not found correctly')
		print ('When parsing %s' % string)
		sys.exit (1)

	fields = braces.str(record[1:-1], '(', ')').split (',')
	# Walk the fields backwards so that a "name :: type" entry is seen
	# before the preceding names that share its type.
	fields.reverse()
	current_type = None
	result = []
	for field in fields:
		parts = field.split ('::')
		if len(parts) == 2:
			current_type = type_transform (str(parts[1]).strip())
			name = str(parts[0]).strip()
		else:
			name = str(field).strip()
		result.append ((name, current_type))
	result.reverse()
	return (header, result)

# One-element list used as a mutable counter; body_transform increments
# it to number each axiomatised liftIO operation.
numLiftIO = [0]

def body_transform (body, defined, sig, nopattern=False):
	r"""Translate a single Haskell definition into Isabelle syntax.

	body must be a one-element list [(line, children)].  The first '='
	of the definition becomes '\<equiv>', looking on the first child
	line if the head line has none; guarded definitions get it
	appended later.  A definition whose text mentions liftIO is
	replaced by an 'underlying_arch_op' call numbered by numLiftIO.
	A trailing where clause becomes a let ... in prefix.  The zipWith,
	supplied, case, do, guard, lhs, type-assertion and regex
	transforms are then applied.

	defined is used only in the missing-'=' warning.  sig is the raw
	type signature, used to choose the monad flavour of do blocks.
	nopattern disables the pattern-match and lhs transforms.  Returns
	[(line, children)] with the whole definition wrapped in double
	quotes."""
	# assume single object
	[(line, children)] = body

	if '(' in line.split('=')[0] and not nopattern:
		[(line, children)] = \
			pattern_match_transform ([(line, children)])

	if 'liftIO' in reduce_to_single_line ((line, children)):
		# liftIO is the translation boundary for current
		# SEL4, below which we get into details of interaction
		# with the foreign function interface - axiomatise
		assert '=' in line
		line = line.split ('=')[0]
		bits = line.split()
		numLiftIO[0] = numLiftIO[0] + 1
		rhs = '(%d :: Int) %s' % (numLiftIO[0], ' '.join(bits[1:]))
		line = '%s\<equiv> underlying_arch_op %s' % (line, rhs)
		children = []
	elif '=' in line:
		line = '\<equiv>'.join (line.split ('=', 1))
	# The children checks below used to index children[0] unguarded,
	# so a childless definition without '=' raised IndexError instead
	# of reaching the warning.
	elif children and leading_bar.match (children[0][0]):
		pass
	elif children and '=' in children[0][0]:
		(nextline, c2) = children[0]
		children[0] = ('\<equiv>'.join (nextline.split ('=', 1)), c2)
	else:
		sys.stderr.write ('WARN: def of %s missing =\n' % defined)

	if children and (children[-1][0].endswith ('where')
			or children[-1][0].lstrip().startswith('where')):
		bits = line.split ('\<equiv>')
		where_clause = where_clause_transform (children[-1])
		children = children[:-1]
		if len(bits) == 2 and bits[1].strip():
			# move the right-hand side below the let ... in prefix
			line = bits[0] + '\<equiv>'
			new_line = ' ' * len(line) + bits[1]
			children = [(new_line, children)]
	else:
		where_clause = []

	(line, children) = zipWith_transforms (line, children)

	(line, children) = supplied_transforms (line, children)

	(line, children) = case_clauses_transform ((line, children))

	(line, children) = do_clauses_transform ((line, children), sig)

	if children and leading_bar.match (children[0][0]):
		line = line + ' \<equiv>'
		children = guarded_body_transform (children, ' = ')

	children = where_clause + children

	if not nopattern:
		line = lhs_transform (line)
	line = lhs_de_underscore (line)

	(line, children) = type_assertion_transform (line, children)

	(line, children) = run_regexes ((line, children))
	(line, children) = run_ext_regexes ((line, children))

	(line, children) = bracket_dollar_lambdas ((line, children))

	line = '"' + line
	(line, children) = add_trailing_string ('"', (line, children))
	return [(line, children)]

dollar_lambda_regex = re.compile (r"\$\s*\\<lambda>")

def bracket_dollar_lambdas (pair):
	r"""Rewrite "f $ \<lambda> x. body" as "f (\<lambda> x. body)".

	pair is a (line, children) tree.  Every "$ \<lambda>" on a line
	opens a bracket.  All the brackets a line opens are closed at the
	end of that line's whole subtree, before a trailing ';' if there is
	one.  A lambda extends as far right as possible, so this nesting
	is correct.  Children are processed recursively.  Returns the
	rewritten (line, children) tree."""
	(line, children) = pair
	pieces = dollar_lambda_regex.split (line)
	# A line may contain several "$ \<lambda>".  The old two-way unpack
	# of the split result crashed on that; each occurrence now gets its
	# own closing bracket.
	opened = len(pieces) - 1
	if opened:
		line = '(\<lambda>'.join (pieces)
		closing = ')' * opened
		both = (line, children)
		if has_trailing_string(';', both):
			both = remove_trailing_string(';', both)
			(line, children) = add_trailing_string (closing + ';', both)
		else:
			(line, children) = add_trailing_string (closing, both)
	children = [bracket_dollar_lambdas (elt) for elt in children]
	return (line, children)

def zipWith_transforms (line, children):
	"""Rewrite a `zipWithM_ f xs ys` call whose list arguments include an
	enumeration ("[a ..]" or "[a, b ..]") into `mapM_ (split f ...)`
	applied to a zipE1..zipE4 expression.

	When the line does not mention zipWithM_, its children are searched
	recursively.  When it does, the call is only rewritten if its last
	child line is a leaf containing '..]'; otherwise the tree is
	returned unchanged.  The zipE variant depends on whether the
	enumeration is the last or second-to-last token and on whether it
	contains a comma.  Returns the rewritten (line, children)."""
	if not 'zipWithM_' in line:
		children = [zipWith_transforms (l, c) for (l, c) in children]
		return (line, children)
	
	if children == []:
		return (line, [])

	# Re-join arguments that were split over two leaf lines so the
	# enumeration test below sees them together.
	if len(children) == 2:
		[(l, c), (l2, c2)] = children
		if c == [] and c2 == [] and '..]' in l + l2:
			children = [(l + ' ' + l2.strip(), [])]

	# Only handle calls whose final line is a leaf holding an
	# enumeration.
	(l, c) = children[-1]
	if c != [] or '..]' not in l:
		return (line, children)
	
	bits = line.split('zipWithM_', 1)
	line = bits[0] + 'mapM_'
	ws = lead_ws (l)
	line2 = ws + '(split ' + bits[1]

	# Tokenise the last line; braces=True presumably keeps each [...]
	# group as one token (see the braces module).
	bits = braces.str(l, '[', ']').split(None, braces=True)
	
	# Everything but the last two tokens stays with the split function.
	line3 = ws + ' '.join(bits[:-2]) + ')'
	used_children = children[:-1] + [(line3, [])]

	# The last two tokens are presumably the two lists being zipped;
	# whichever one is the enumeration selects the zipE variant.
	sndlast = bits[-2]
	last = bits[-1]
	if last.endswith ('..]'):
		internal = last[1:-3].strip()
		if ',' in internal:
			bits = internal.split(',')
			l = '%s(zipE4 (%s) (%s) (%s))' \
				% (ws, sndlast, bits[0], bits[-1])
		else:
			l = '%s(zipE3 (%s) (%s))' % (ws, sndlast, internal)
	else:
		internal = sndlast[1:-3].strip()
		if ',' in internal:
			bits = internal.split(',')
			l = '%s(zipE2 (%s) (%s) (%s))' \
				% (ws, bits[0], bits[1], last)
		else:
			l = '%s(zipE1 (%s) (%s))' % (ws, internal, last)
	
	return (line, [(line2, used_children), (l, [])])

def supplied_transforms (line, children):
	"""Replace hand-supplied translations in a (line, children) tree.

	If the whole subtree, compared with whitespace stripped, is a key of
	supplied_transform_table, the supplied replacement is returned,
	re-indented to this line's indentation, and the key is recorded in
	supplied_transforms_usage.  Otherwise each child is searched
	recursively."""
	key = convert_to_stripped_tuple (line, children)
	if key not in supplied_transform_table:
		return (line, [supplied_transforms (l, c) for (l, c) in children])

	replacement = supplied_transform_table[key]
	shift = len(lead_ws (line)) - len(lead_ws (replacement[0]))
	adjusted = adjust_ws (replacement, shift)
	supplied_transforms_usage[key] = 1
	return adjusted

def convert_to_stripped_tuple (line, children):
	"""Return a hashable copy of a (line, children) tree: every line is
	whitespace-stripped and every child list becomes a tuple."""
	stripped_children = tuple (convert_to_stripped_tuple (l, c)
			for (l, c) in children)
	return (line.strip(), stripped_children)

def type_assertion_transform_inner (line):
	"""Rewrite every "(expr :: type)" annotation in line as
	"(expr::type')", where type' is the stripped type converted by
	type_transform.  Annotations are processed left to right and the
	rest of the text is kept unchanged."""
	pieces = []
	rest = line
	m = type_assertion.search (rest)
	while m:
		converted = type_transform (m.group(2).strip())
		pieces.append (rest[:m.start()] \
				+ '(%s::%s)' % (m.group(1), converted))
		rest = rest[m.end():]
		m = type_assertion.search (rest)
	pieces.append (rest)
	return ''.join (pieces)

def type_assertion_transform (line, children):
	"""Apply type_assertion_transform_inner to every line of a
	(line, children) tree, children first."""
	new_children = []
	for (l, c) in children:
		new_children.append (type_assertion_transform (l, c))
	return (type_assertion_transform_inner (line), new_children)

def where_clause_guarded_body (pair):
	"""If a where-clause binding is followed by guard lines ("| ..."),
	append ' =' to its head and turn the guards into an if/else chain.
	Otherwise return the (line, children) pair unchanged."""
	(line, children) = pair
	if not children or not leading_bar.match (children[0][0]):
		return (line, children)
	return (line + ' =', guarded_body_transform (children, ' = '))


def where_clause_transform (pair):
	"""Convert a Haskell `where` block into an Isabelle let ... in block.

	pair is the (line, children) tree whose line starts with 'where'.
	A binding written on the 'where' line itself becomes the first
	binding.  Type-signature lines ("name :: type") are dropped.  Case,
	do and guard syntax inside each binding is translated, and function
	bindings "f a = b" become "f = (\\ a -> b)".  The bindings are then
	reordered by order_let_children.  Returns a list of trees:
	[let-line] + bindings (all but the last ending in ';') + [in-line]."""
	(line, children) = pair
	ws = line.split('where', 1)[0]
	if line.strip() != 'where':
		assert line.strip().startswith('where')
		children = [(''.join (line.split('where', 1)), [])] + children
	pre = ws + 'let'
	post = ws + 'in'

	# Drop type signatures.  A line with fewer than two tokens cannot be
	# one; such lines used to raise IndexError here.
	children = [(l, c) for (l, c) in children
			if len(l.split()) < 2 or l.split()[1] != '::']
	children = [case_clauses_transform ((l, c)) for (l, c) in children]
	children = [do_clauses_transform ((l, c), None, type=0)
			for (l, c) in children]
	# A real list is needed because entries are reassigned below.
	children = [where_clause_guarded_body (child) for child in children]
	for i, (l, c) in enumerate (children):
		l2 = braces.str(l, '(', ')')
		if len(l2.split('=')[0].split()) > 1:
			# turn f a = b into f = (\a -> b)
			l = '->'.join(l.split('=', 1))
			l = lead_ws(l) + ' = (\\ '.join(l.split(None, 1))
			(l, c) = add_trailing_string (')', (l, c))
			children[i] = (l, c)
	children = order_let_children (children)
	for i, child in enumerate (children[:-1]):
		children[i] = add_trailing_string (';', child)
	return [(pre, [])] + children + [(post, [])]

varname_re = re.compile (r"\w+")

def order_let_children (L):
	"""Order let bindings so that each one follows the bindings it uses.

	L is a list of (line, children) bindings.  A binding's name is the
	first word to the left of its first '='.  Binding i depends on
	binding j (j != i) if any word of i's single-line text is j's name.
	Bindings are emitted in repeated passes over the pending indices in
	ascending order.  If a pass makes no progress, the pending bindings
	and the dependency map are printed and an assertion fails."""
	flat = [reduce_to_single_line (elt) for elt in L]
	index_of = {}
	for idx, text in enumerate (flat):
		name = str(braces.str(text, '(', ')').split('=')[0]).split()[0]
		index_of[name] = idx
	deps = {}
	for idx, text in enumerate (flat):
		for word in varname_re.findall (text):
			if word in index_of and index_of[word] != idx:
				deps.setdefault(idx, {})[index_of[word]] = 1

	done = {}
	ordered = []
	pending = dict (enumerate (L))
	remaining = len(pending)
	while remaining:
		for idx in sorted (pending):
			blockers = [j for j in deps.get(idx, {}) if j not in done]
			if not blockers:
				ordered.append (pending.pop(idx))
				done[idx] = 1
		if len(pending) == remaining:
			print ("No progress resolving let deps")
			print ('')
			print (pending.values())
			print ('')
			print ("Dependencies:")
			print (deps)
			assert 0
		remaining = len(pending)
	return ordered

def do_clauses_transform ((line, children), rawsig, type = None):
	r"""Translate Haskell do blocks in a (line, children) tree into
	Isabelle do ... od blocks.

	rawsig is the raw type signature of the enclosing definition, or a
	false value.  When `type` is None, rawsig is flattened and passed to
	monad_type_acquire to obtain it.  `type` is a repetition count for
	the monad-specific suffixes: 'E' * type on do/od, and 'Ok' * type
	on the 'return' introduced for let statements.  Returns the
	rewritten (line, children) tree."""
	# A trailing where block becomes "let ... in" placed in front of
	# the rest of this tree.
	if children and children[-1][0].lstrip().startswith('where'):
		where_clause = where_clause_transform (children[-1])
		where_clause = [do_clauses_transform ((l, c), rawsig, 0)
					for (l, c) in where_clause]
		others = (line, children[:-1])
		others = do_clauses_transform (others, rawsig, type)
		(line, children) = where_clause[0]
		return (line, children + where_clause[1:] + [others])

	# Pull a lone "do" continuation line up onto this line.
	if children:
		(l, c) = children[0]
		if c == [] and l.endswith('do'):
			line = line + ' ' + l.strip()
			children = children[1:]

	if type == None:
		if not rawsig:
			type = 0
			sig = None
		else:
			sig = flatten_tree([rawsig])
			sig = ' '.join(sig)
			type = monad_type_acquire (sig)
	(line, type) = monad_type_transform ((line, type))
	if children == []:
		return (line, [])

	# Right-hand side of a bind on this line (the whole line otherwise).
	rhs = line.split('<-', 1)[-1]
	if rhs.strip() == 'syscall' or rhs.strip() == 'atomicSyscall':
		# syscall / atomicSyscall take exactly five argument blocks,
		# each translated with a fixed type count.
		assert len(children) == 5
		children = [do_clauses_transform (elt, rawsig, type = subtype)
				for elt, subtype in
				zip (children, [1, 0, 1, 0, type])]
	elif line.strip().endswith('catchFailure'):
		# The two blocks of catchFailure get type counts 1 and 0.
		assert len(children) == 2
		children = [do_clauses_transform (elt, rawsig, type = subtype)
				for elt, subtype in
				zip (children, [1, 0])]
	else:
		children = [do_clauses_transform (elt, rawsig, type = type)
					for elt in children]
	
	if not line.endswith('do'):
		return (line, children)
	
	# Lines after a ')' that closes a bracket opened before this block
	# are not part of the do block; keep them aside and re-attach them
	# after the closing 'od'.
	children, other_children = split_on_unmatched_bracket (children)
	
	# single statement do clause won't parse in Isabelle
	if len(children) == 1:
		ws = lead_ws (line)
		return (line[:-2] + '(', children + [(ws + ')', [])])

	line = line[:-2] + '(do' + 'E' * type

	# NOTE(review): blank lines are only dropped here, after the
	# single-statement check above, so one statement plus blank lines
	# still takes the do ... od path -- confirm whether that matters.
	children = [(l, c) for (l, c) in children if l.strip() or c]

	children2 = []
	for i, (l, c) in enumerate (children):
		if l.lstrip().startswith('let '):
			# "let x = e" becomes "x <- return (e)"; if the '=' is on
			# the next line, that line is merged in first.
			if '=' not in l:
				extra = reduce_to_single_line (c[0])
				assert '=' in extra
				l = l + ' ' + extra
				c = c[1:]
			l = ''.join (l.split ('let ', 1))
			letsubs = '<- return' + 'Ok' * type + ' ('
			l = letsubs.join (l.split ('=', 1))
			(l, c) = add_trailing_string (')', (l, c))
			children2.extend(do_clause_pattern (l, c, type))
		else:
			children2.extend (do_clause_pattern (l, c, type))

	# Statements are separated by ';' (none after the last).
	children = [add_trailing_string (';', child)
			for child in children2[:-1]] + [children2[-1]]

	ws = lead_ws (line)
	children.append ((ws + 'od' + 'E' * type + ')', []))

	return (line, children + other_children)

# Constructor prefixes that do_clause_pattern strips from the left of a
# "<-" bind: "Cons x <- m" is rewritten as "x <- liftM f $ m" (liftME...
# in error monads), where f is the projection given here.
left_start_table = {
	'ASIDPool' : '(inv ASIDPool)',
	'HardwareASID' : 'id',
	'ArchObjectCap' : 'capCap',
	'Just' : 'the'
}

def do_clause_pattern (line, children, type, n = 0):
	r"""Expand one do-block statement whose bind target is a pattern.

	Statements without '<-', binds to a single plain name and binds to
	tuple patterns are returned unchanged, apart from '<-' becoming
	'\<leftarrow>'.  Cons patterns (x:xs), '[]', list literals and the
	constructor prefixes in left_start_table are rewritten into a bind
	of a fresh variable vN followed by headM/tailM binds, assertions or
	liftM projections, and the results are expanded recursively.  type
	selects the 'E' suffix of assert/liftM.  n is the recursion depth
	and is only passed along.  An unsupported pattern prints the line
	and fails an assertion.  Returns a list of (line, children)
	statements."""
	bits = line.split('<-')
	default = [('\<leftarrow>'.join(bits), children)]
	if len(bits) == 1:
		return default
	(left, right) = line.split('<-', 1)
	if ':' not in left and '[' not in left \
		and len(left.split()) == 1:
		return default
	left = left.strip()

	# Note: a fresh id is consumed even on the tuple path below, which
	# returns `default` without using it.
	v = 'v%d' % get_next_unique_id()

	ass = 'assert' + ('E' * type)
	ws = lead_ws (line)

	if left.startswith ('('):
		assert left.endswith (')')
		if (',' in left):
			return default
		else:
			left = left[1:-1]
	bs = braces.str (left, '[', ']')
	# x:xs  ->  v <- rhs; x <- headM v; xs <- tailM v
	if len(bs.split (':')) > 1:
		bits = [str(bit).strip() for bit in bs.split(':', 1)]
		lines = [('%s%s <- %s' % (ws, v, right), children),
			('%s%s <- headM %s' % (ws, bits[0], v), []),
			('%s%s <- tailM %s' % (ws, bits[1], v), [])]
		result = []
		for (l, c) in lines:
			result.extend (do_clause_pattern (l, c, type, n + 1))
		return result
	# []  ->  v <- rhs; assert (v = []) []
	if left == '[]':
		return [('%s%s <- %s' % (ws, v, right), children),
			('%s%s (%s = []) []' % (ws, ass, v), [])]
	# [a]  ->  a:v plus an emptiness assertion on v;
	# [a, b ...]  ->  a:[b ...]
	if left.startswith ('['):
		assert left.endswith (']')
		bs = braces.str (left[1:-1], '[', ']')
		bits = bs.split(',', 1)
		if len(bits) == 1:
			new_left = '%s:%s' % (bits[0], v)
			new_line = '%s%s <- %s' % (ws, new_left, right)
			extra = [('%s%s (%s = []) []' % (ws, ass, v), [])]
		else:
			new_left = '%s:[%s]' % (bits[0], bits[1])
			new_line = lead_ws (line) + new_left + ' <- ' + right
			extra = []
		return do_clause_pattern (new_line, children, type, n + 1) \
			+ extra
	# Cons x  ->  x <- liftM projection $ rhs
	for lhs in left_start_table:
		if left.startswith (lhs):
			left = left[len(lhs):]
			tab = left_start_table[lhs]
			lM = 'liftM' + 'E' * type
			nl = ('%s <- %s %s $ %s' % (left, lM, tab, right))
			return do_clause_pattern (nl, children, type, n + 1)

	# Unsupported pattern: show it and abort.
	print line
	print left
	assert 0

def split_on_unmatched_bracket (elts, n = None):
	"""Split a list of (line, children) trees at the first ')' that
	closes a bracket opened before the list.

	Returns (inside, outside).  inside holds everything before that
	bracket, with the line containing it cut just before the ')'.
	outside starts with the cut-off text (padded with spaces back to its
	original column) carrying that line's children, followed by the
	later siblings.  If there is no such bracket, outside is [].  When n
	(the current bracket depth) is passed, the final depth is returned
	as a third element."""
	if n is None:
		inside, outside, _ = split_on_unmatched_bracket (elts, 0)
		return (inside, outside)

	depth = n
	for idx, (text, kids) in enumerate (elts):
		for col, ch in enumerate (text):
			if ch == '(':
				depth += 1
			elif ch == ')':
				depth -= 1
				if depth < 0:
					head = text[:col]
					tail = ' ' * col + text[col:]
					return (elts[:idx] + [(head, [])],
						[(tail, kids)] + elts[idx + 1:],
						depth)
		kids_in, kids_out, depth = \
			split_on_unmatched_bracket (kids, depth)
		if kids_out:
			return (elts[:idx] + [(text, kids_in)],
				kids_out + elts[idx + 1:],
				depth)

	return (elts, [], depth)

def monad_type_acquire (sig, type=0):
	"""Derive the do-block type count from a flattened type signature.

	Repeatedly takes the first name in the list below that occurs in
	sig, records its count, and continues with the text after the last
	occurrence of that name.  Returns the count from the final step, or
	`type` if no listed name occurs."""
	# Order matters: 'kernel' is tried after the longer names that
	# contain it (kernel_f, kernel_monad, ...).
	monads = (('kernel_f', 1), ('fault_monad', 1),
			('syscall_monad', 2), ('kernel_monad', 0), ('kernel_init', 1),
			('kernel_p', 1), ('kernel', 0))
	while True:
		for name, count in monads:
			if name in sig:
				sig = sig.split(name)[-1]
				type = count
				break
		else:
			return type

# Markers that switch the monad type count for the text after them, in
# the order they are tested, each paired with the count for that text.
# None marks `catchFailure`, which is handled specially.
_monad_switches = (
	('withoutError', 1),
	('doKernelOp', 0),
	('runInit', 1),
	('withoutFailure', 0),
	('withoutFault', 0),
	('withoutPreemption', 0),
	('allowingFaults', 1),
	('allowingErrors', 2),
	('`catchFailure`', None),
	('catchingFailure', 1),
	('catchF', 1),
	('emptyOnFailure', 1),
	('constOnFailure', 1),
	('nothingOnFailure', 1),
	('nullCapOnFailure', 1),
	('`catchFault`', 1),
	('capFaultOnFailure', 1),
	('ignoreFailure', 1),
	('handleInvocation False', 0), # THIS IS A HACK
)

# Monadic combinators that get 'E' * type appended (after 'return' has
# been given 'Ok' * type).
_monad_suffixed_words = ('when', 'unless', 'mapM', 'forM', 'liftM',
		'assert', 'stateAssert')

def monad_type_transform (pair):
	"""Rename monadic combinators in a line according to its type count.

	pair is (line, type).  If the line contains one of the markers in
	_monad_switches, the text before the first such marker keeps `type`,
	and the text after it is transformed with the marker's count.  For
	`catchFailure`, the left side uses 1 and the right side 0.
	Otherwise, when type is non-zero, 'return' gets 'Ok' * type and the
	words in _monad_suffixed_words get 'E' * type.  Returns
	(new_line, resulting_type)."""
	(line, type) = pair
	for marker, newtype in _monad_switches:
		if marker not in line:
			continue
		left, right = line.split (marker, 1)
		if newtype is None:
			left, _ = monad_type_transform ((left, 1))
			right, type = monad_type_transform ((right, 0))
			return (left + marker + right, type)
		left, _ = monad_type_transform ((left, type))
		right, newnewtype = monad_type_transform ((right, newtype))
		return (left + marker + right, newnewtype)

	if type:
		line = line.replace ('return', 'return' + 'Ok' * type)
		for word in _monad_suffixed_words:
			line = line.replace (word, word + 'E' * type)

	return (line, type)

def case_clauses_transform (pair):
	"""Translate Haskell case expressions in a (line, children) tree.

	Children are transformed first.  A line ending in ' of' is a case
	head, and its children are the alternatives ("pattern -> body").
	The tuple of patterns is looked up with get_case_conv, and the
	resulting template is instantiated with the scrutinee and the
	alternative bodies.

	A case that has a where clause, is blanked with '<X>' in caseconvs,
	or has no known conversion is replaced by
	"(* case removed *) undefined" with a warning.  Unknown pattern
	tuples are appended once to the caseconvs file as '---X>' entries.
	A case with no alternatives, or a template that refers to a missing
	alternative, is a fatal error.  Returns the rewritten
	(line, children)."""
	(line, children) = pair
	children = [case_clauses_transform (child)
		for child in children]

	if not line.endswith (' of'):
		return (line, children)

	bits = line.split ('case ', 1)
	beforecase = bits[0]
	x = bits[1][:-3]

	# Convert a type annotation on the scrutinee.
	if '::' in x:
		x2 = braces.str(x, '(', ')')
		bits = x2.split('::', 1)
		if len(bits) == 2:
			x = str(bits[0]) + ':: ' + type_transform (str(bits[1]))

	# where clauses inside a case are not supported.  (Unreachable
	# translation code that used to follow this return was removed.)
	if children and children[-1][0].strip().startswith('where'):
		sys.stderr.write ('Warning: where clause in case: %r\n' \
					% line)
		return (beforecase + '(* case removed *) undefined', [])

	cases = []
	bodies = []
	for (l, c) in children:
		bits = l.split ('->', 1)
		# Join continuation lines until this alternative's '->' is
		# found, or a guard block supplies the body.
		while len(bits) == 1:
			if c == []:
				sys.stderr.write ('wtf %r\n' % l)
				sys.exit(1)
			if c[0][0].strip().startswith ('|'):
				bits = [bits[0], '']
				c = guarded_body_transform (c, '->')
			elif c[0][1] == []:
				l = l + ' ' + c.pop(0)[0].strip()
				bits = l.split ('->', 1)
			else:
				[(moreline, c)] = c
				l = l + ' ' + moreline.strip()
				bits = l.split ('->', 1)
		case = bits[0].strip()
		tail = bits[1]
		if c and c[-1][0].lstrip().startswith('where'):
			where_clause = where_clause_transform (c[-1])
			ws = lead_ws(where_clause[0][0])
			c = where_clause + [(ws + tail.strip(), [])] + c[:-1]
			tail = ''
		cases.append (case)
		bodies.append ((tail, c))

	cases = tuple(cases)
	if not cases:
		# Previously this printed the line to stdout and then crashed
		# with an IndexError inside get_case_conv.
		sys.stderr.write ('ERROR: case with no alternatives: %r\n' % line)
		sys.exit(1)
	conv = get_case_conv(cases)
	if conv == '<X>':
		sys.stderr.write ('Warning: blanked case in caseconvs\n')
		return (beforecase + '(* case removed *) undefined', [])
	if not conv:
		sys.stderr.write ('Warning: case %r\n' % (cases,))
		if cases not in cases_added:
			casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> '
			with open ('caseconvs', 'a') as f:
				f.write ('%s ---X>\n\n' % casestr)
			cases_added[cases] = 1
		return (beforecase + '(* case removed *) undefined', [])
	conv = subs_nums_and_x (conv, x)

	new_line = beforecase + '(' + conv[0][0]
	assert conv[0][1] is None

	ws = lead_ws (children[0][0])
	new_children = []
	for (header, endnum) in conv[1:]:
		if endnum is None:
			new_children.append ((ws + header, []))
		else:
			if len(bodies) <= endnum:
				sys.stderr.write('ERROR: index %d out of bounds in case %r\n' % (endnum, cases,))
				sys.exit(1)
			(body, c) = bodies[endnum]
			new_children.append ((ws + header + ' ' + body, c))

	# Close the bracket opened on new_line, keeping a trailing ','
	# outside it.
	if has_trailing_string (',', new_children[-1]):
		new_children[-1] = \
			remove_trailing_string (',', new_children[-1])
		new_children.append ((ws + '),', []))
	else:
		new_children.append ((ws + ')', []))

	return (new_line, new_children)


def guarded_body_transform (body, div):
	"""Turn Haskell guard lines into an if / else if chain.

	body is a list of (line, children) trees of the form
	"| cond <div> rhs".  If div is not on the line itself, child lines
	are pulled up until it is.  A trailing where clause is translated
	with where_clause_transform and placed first.  The chain ends with
	'else undefined'.  Exits with a message if a guard has no div or no
	leading '| '.  Returns the new list of trees."""
	new_body = []
	if body[-1][0].strip().startswith ('where'):
		new_body.extend(where_clause_transform(body[-1]))
		body = body[:-1]
	for i, (line, children) in enumerate(body):
		# Only the expected failures are reported; a bare except used to
		# also swallow KeyboardInterrupt and SystemExit.
		try:
			while div not in line:
				(l, c) = children[0]
				children = c + children[1:]
				line = line + ' ' + l.strip()
		except IndexError:
			sys.stderr.write ('missing %r in %r\n' % (div, line))
			sys.stderr.write ('\nhandling %r\n' % body)
			sys.exit(1)
		try:
			[ws, guts] = line.split ('| ', 1)
		except ValueError:
			sys.stderr.write ('missing "|" in %r\n' % line)
			sys.stderr.write ('\nhandling %r\n' % body)
			sys.exit(1)
		if i == 0:
			new_body.append ((ws + 'if', []))
		else:
			new_body.append ((ws + 'else if', []))
		guts = ' then '.join(guts.split (div, 1))
		new_body.append ((ws + guts, children))
	new_body.append ((ws + 'else undefined', []))

	return new_body

def lhs_transform (line):
	r"""Move parenthesised patterns on the left of '\<equiv>' into case
	expressions.

	The left-hand side is split into tokens, with bracketed groups kept
	together, and token i is counted from 0.  Each bracketed token i
	becomes 'argi', and "case argi of <pattern> \<Rightarrow> " is
	prepended to the right-hand side.  Lines without '(' are returned
	unchanged.  So are lines without '\<equiv>', since there is no
	right-hand side on the line to receive the case expressions.  Only
	the first '\<equiv>' separates the two sides."""
	if not '(' in line:
		return line

	# partition instead of a two-way unpack: a missing or repeated
	# '\<equiv>' used to raise ValueError.
	(left, sep, right) = line.partition ('\<equiv>')
	if not sep:
		return line

	ws = left[:len(left) - len(left.lstrip())]

	left = left.lstrip()

	bits = braces.str(left, '(', ')').split(braces = True)

	for (i, bit) in enumerate(bits):
		if bit.startswith('('):
			bits[i] = 'arg%d' % i
			case = 'case arg%d of %s \<Rightarrow> ' % (i, bit)
			right = case + right

	return ws + ' '.join([str(bit) for bit in bits]) + '\<equiv>' + right

def lhs_de_underscore (line):
	r"""Replace '_' wildcard arguments on a definition's left-hand side
	with 'argi', where i is the token's position.

	The left-hand side is the text before the first '\<equiv>', or the
	whole line if there is none (e.g. when the '=' was on a continuation
	line).  Its tokens are re-joined with single spaces after the
	original indentation.  If a '\<equiv>' is present, ' \<equiv>' and
	the unchanged right-hand side follow.  Lines without any '_' are
	returned unchanged."""
	if not '_' in line:
		return line

	# partition instead of a two-way unpack: a missing or repeated
	# '\<equiv>' used to raise ValueError.
	(left, sep, right) = line.partition ('\<equiv>')

	ws = left[:len(left) - len(left.lstrip())]

	bits = left.split()

	for (i, bit) in enumerate(bits):
		if bit == '_':
			bits[i] = 'arg%d' % i

	new_left = ws + ' '.join([str(bit) for bit in bits])
	if not sep:
		return new_left
	return new_left + ' \<equiv>' + right
	

# Line-level (compiled_regex, replacement) rewrites from Haskell to
# Isabelle syntax, applied in order to every line by run_regexes.  The
# order matters: several entries produce text that a later entry
# rewrites again (e.g. '!!' is protected as LIST_INDEX while the other
# '!' uses are removed, then restored).
regexes = [
	# composition, negative literals, sections
	(re.compile (r" \. "), r" \<circ> "),
	(re.compile ('-1'), '- 1'),
	(re.compile ('-2'), '- 2'),
	(re.compile (r"\(!(\w+)\)"), r"(flip id \1)"),
	(re.compile (r"\(\+(\w+)\)"), r"(\<lambda> x. x + \1)"),
	# lambdas "\x -> e" become "\<lambda> x. e"
	(re.compile (r"\\([^<].*?)\s*->"), r"\<lambda> \1."),
	(re.compile ('}'), r"\<rparr>"),
	# '!!' survives as list indexing '!'; other '!' are dropped
	(re.compile (r"(\s)!!(\s)"), r"\1LIST_INDEX\2"),
	(re.compile (r"(\w)!"), r"\1 "),
	(re.compile (r"\s?!"), ''),
	(re.compile (r"LIST_INDEX"), r"!"),
	(re.compile ('`testBit`'), '!!'),
	(re.compile (r"//"), ' aLU '),
	(re.compile ('otherwise'), 'True     '),
	(re.compile (r"(^|\W)fail "), r"\1haskell_fail "),
	(re.compile ('assert '), 'haskell_assert '),
	(re.compile ('assertE '), 'haskell_assertE '),
	# comparison and boolean operators
	(re.compile ('=='), '='),
	(re.compile (r"\(/="), '(\<lambda>x. x \<noteq>'),
	(re.compile ('/='), '\<noteq>'),
	# string literals become empty lists
	(re.compile ('"([^"])*"'), '[]'),
	(re.compile ('&&'), '\<and>'),
	(re.compile ('\|\|'), '\<or>'),
	(re.compile (r"(\W)not(\s)"), r"\1Not\2"),
	(re.compile (r"(\W)and(\s)"), r"\1andList\2"),
	(re.compile (r"(\W)or(\s)"), r"\1orList\2"),
	# regex ordering important here
	(re.compile (r"\.&\."), '&&'),
	(re.compile (r"\(\.\|\.\)"), r"bitOR"),
	(re.compile (r"\(\+\)"), r"op +"),
	(re.compile (r"\.\|\."), '||'),
	# qualified names and renamed library functions
	(re.compile (r"Data\.Word\.Word"), r"word"),
	(re.compile (r"Data\.Map\."), r"data_map_"),
	(re.compile (r"BinaryTree\."), 'bt_'),
	(re.compile ("mapM_"), "mapM_x"),
	(re.compile ("forM_"), "forM_x"),
	(re.compile ("forME_"), "forME_x"),
	(re.compile ("mapME_"), "mapME_x"),
	(re.compile ("zipWithM_"), "zipWithM_x"),
	(re.compile (r"bit\s+([0-9]+)(\s)"), r"(1 << \1)\2"),
	(re.compile ('`mod`'), 'mod'),
	(re.compile ('`div`'), 'div'),
	(re.compile (r"`((\w|\.)*)`"), r"`~\1~`"),
	(re.compile ('size'), 'magnitude'),
	(re.compile ('foldr'), 'foldR'),
	# list syntax: ++ -> @, cons ':' -> '#', enumerations [a..b]
	(re.compile (r"\+\+"), '@'),
	(re.compile (r"(\s|\)|\w|\]):(\s|\(|\w|$|\[)"), r"\1#\2"),
	(re.compile (r"\[([^]]*)\.\.([^]]*)\]"), r"[\1 .e. \2]"),
	(re.compile (r"\[([^]]*)\.\.\s*$"), r"[\1 .e."),
	# Either constructors
	(re.compile (' Right'), ' Inr'),
	(re.compile (' Left'), ' Inl'),
	(re.compile ('\\(Left'), '(Inl'),
	(re.compile ('\\(Right'), '(Inr'),
	(re.compile (r"\$!"), r"$"),
	(re.compile ('([^>])>='), r'\1\<ge>'),
	(re.compile ('<='), '\<le>'),
	(re.compile (r" \\\\ "), " `~listSubtract~` "),
	# drop "name@Cons {}" as-patterns on the left of a bind
	(re.compile (r"(\s\w+)\s*@\s*\w+\s*{\s*}\s*\<leftarrow>"),
		r"\1 \<leftarrow>"),
]

# Context-dependent rewrites used by run_ext_regexes.  Each entry is
# (regex, replacement, follow_regex, follow_replacement): the first
# match of regex on a line is replaced, then follow_regex is applied to
# the rest of that line and to its children.  Judging by the patterns,
# these turn Haskell record construction/update braces into Isabelle
# record syntax (\<lparr> ... \<rparr>, with '=' becoming ':=') --
# verify against sample output before relying on this summary.
ext_regexes = [
	# "Cons {" -> "Cons_ \<lparr>", tidying "x =" to "x=" afterwards
	(re.compile (r"(\s[A-Z]\w*)\s*{"), r"\1_ \<lparr>", 
		re.compile (r"(\w)\s*="), r"\1="),
	(re.compile (r"(\([A-Z]\w*)\s*{"), r"\1_ \<lparr>", 
		re.compile (r"(\w)\s*="), r"\1="),
	# single-field update "{ f = v \<rparr>"; the follow regex never matches
	(re.compile (r"{([^={<]*[^={<:])=([^={<]*)\\<rparr>"),
		r"\<lparr>\1:=\2\<rparr>",
		re.compile (r"THIS SHOULD NOT APPEAR IN THE SOURCE"), ""),
	# any other "{": following "=" become ":="
	(re.compile (r"{"), r"\<lparr>",
		re.compile (r"([^=:])=(\s|$|\w)"), r"\1:=\2"),
]

# Matches a line that starts, after optional whitespace, with a guard
# bar '|' (used with .match).
leading_bar = re.compile (r"\s*\|")

# Matches a type annotation "(expr :: type)": group 1 is the expression
# (no '(' allowed) and group 2 the type (no ')' allowed).
type_assertion = re.compile (r"\(([^(]*)::([^)]*)\)")

def run_regexes (pair, _regexes = regexes):
	"""Apply each (compiled_regex, replacement) of _regexes, in order,
	to every line of a (line, children) tree.  Returns the rewritten
	tree."""
	(line, children) = pair
	for pattern, replacement in _regexes:
		line = pattern.sub (replacement, line)
	rewritten = [run_regexes (child, _regexes = _regexes)
			for child in children]
	return (line, rewritten)

def run_ext_regexes (pair):
	"""Apply ext_regexes to every line of a (line, children) tree.

	Each entry is (regex, replacement, follow_regex, follow_replacement).
	Only the first match of regex on a line is replaced.  follow_regex
	is then applied to the text after that match and, through
	run_regexes, to the line's whole subtree of children.  Returns the
	rewritten tree."""
	(line, children) = pair
	for pattern, repl, follow_pattern, follow_repl in ext_regexes:
		m = pattern.search (line)
		if m is None:
			continue
		head = line[:m.start()] + m.expand (repl)
		(tail, children) = run_regexes ((line[m.end():], children),
				_regexes = [(follow_pattern, follow_repl)])
		line = head + tail
	return (line, [run_ext_regexes (child) for child in children])

def get_case_lhs (lhs):
	r"""Parse the left side of a caseconvs entry,
	"case \x of p1 -> p2 -> ...", into a tuple of the stripped,
	non-empty patterns.  Asserts that the prefix is present."""
	prefix = 'case \\x of '
	assert lhs.startswith (prefix)
	body = lhs.split(prefix, 1)[1]
	stripped = (part.strip() for part in body.split ('->'))
	return tuple (pat for pat in stripped if pat != '')

def get_case_rhs (rhs):
	r"""Parse the right side of a caseconvs entry into a conversion.

	rhs is the Isabelle template, with its original line breaks encoded
	as the two characters '\n'.  A marker '->N' (N counted from 1)
	stands for the body of the N-th case alternative.  Returns a list of
	(text, index) pairs, one per non-empty template line, where index is
	the 0-based alternative whose body follows that line, or None.

	The first line must not carry a marker; if it does, the program
	exits with an explanation showing the full template.  Raises
	ValueError if the template has no content."""
	tuples = []
	# Consume a copy so that rhs still holds the whole template for the
	# error message below (it used to show only the unparsed tail).
	rest = rhs
	while '->' in rest:
		(s, after) = rest.split ('->', 1)
		bits = after.split(None, 1)
		n = int(takeWhile (bits[0], lambda x: x.isdigit())) - 1
		if len(bits) > 1:
			rest = bits[1]
		else:
			rest = ''
		tuples.append ((s, n))
	if rest != '':
		tuples.append ((rest, None))

	conv = []
	for (string, num) in tuples:
		bits = [bit.strip() for bit in string.split ('\\n')]
		conv.extend ([(bit, None) for bit in bits[:-1]])
		conv.append ((bits[-1], num))

	conv = [(s, n) for (s, n) in conv if s != '' or n != None]

	# Reported with context by get_case_conv_table instead of an
	# anonymous IndexError.
	if not conv:
		raise ValueError ('empty case conversion: %r' % rhs)

	if conv[0][1] != None:
		sys.stderr.write ('%r\n' % conv[0][1])
		sys.stderr.write ('For technical reasons the first line of this case conversion must be split with \\n: \n')
		sys.stderr.write ('  %r\n' % rhs)
		sys.stderr.write ('(further notes: the rhs of each caseconv must have multiple lines\n'
			'and the first cannot contain any ->1, ->2 etc.)\n')
		sys.exit(1)

	# this is a tad dodgy, but means that case_clauses_transform
	# can be safely run twice on the same input
	if conv[0][0].endswith('of'):
		conv[0] = (conv[0][0] + ' ', conv[0][1])

	return conv

def render_caseconv (cases, conv, f):
	bits = [bit.rstrip() for bit in conv.split('\\n')]
	bits = [bit for bit in bits if bit != '']
	assert bits
	casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> '
	f.write ('%s --->' % casestr)
	for bit in bits:
		f.write (bit)
		f.write ('\n')
	f.write('\n')

def get_case_conv_table ():
	"""Load the hand-written case conversions from the file 'caseconvs'.

	Entries are separated by blank lines.  An entry 'LHS ---X> ...'
	marks a case to be blanked (stored as '<X>').  An entry
	'LHS ---> RHS' gives a conversion parsed by get_case_rhs.
	Conversions whose patterns are neither all-constructor nor extended
	patterns are also written to 'caseconvs-useful', which is rewritten
	on every call.  Returns a dict from pattern tuples to conversions.
	Exits with a message on a malformed entry."""
	result = {}
	with open ('caseconvs') as f:
		raw_lines = [line.rstrip() for line in f]
	entries = ["\\n".join(lines) for lines in splitList (raw_lines, emptyList)]

	# 'with' ensures caseconvs-useful is flushed and closed even on the
	# sys.exit path below.
	with open ('caseconvs-useful', 'w') as f2:
		for line in entries:
			if line.strip() == '':
				continue
			try:
				if '---X>' in line:
					[from_case, blah] = line.split ('---X>')
					cases = get_case_lhs (from_case)
					result[cases] = "<X>"
				else:
					[from_case, to_case] = line.split ('--->')
					cases = get_case_lhs (from_case)
					conv = get_case_rhs (to_case)
					result[cases] = conv
					if (not all_constructor_patterns(cases) and
							not is_extended_pattern(cases)):
						render_caseconv (cases, to_case, f2)
			except Exception as e:
				sys.stderr.write ('Error parsing %r\n' % line)
				sys.stderr.write ('%s\n ' % e)
				sys.exit (1)

	return result
	
def all_constructor_patterns (cases):
	"""Return True if every pattern in cases is a simple constructor
	pattern (see is_constructor_pattern), ignoring a final '_'
	wildcard."""
	checked = cases[:-1] if cases[-1].strip() == '_' else cases
	return all (is_constructor_pattern (pat) for pat in checked)

def is_constructor_pattern (pat):
	"""A constructor pattern takes the form Cons var1 var2 ...:
	all tokens are alphanumeric or '_', the constructor starts with an
	uppercase letter, and each argument starts with a lowercase letter
	or is '_'."""
	words = pat.split()
	if not all (w.isalnum() or w == '_' for w in words):
		return False
	head, args = words[0], words[1:]
	return head[0].isupper() and \
		all (a[0].islower() or a == '_' for a in args)

# A pattern built only from brackets, braces, ',', '=', ':', '_',
# whitespace and letters (each optionally followed by one digit or a
# quote).
ext_checker = re.compile (r"^(\(|\)|,|{|}|=|[a-zA-Z][0-9']?|\s|_|:|\[|\])*$")

def is_extended_pattern(cases):
	"""Return True if every pattern in cases consists only of the
	characters accepted by ext_checker."""
	return all (ext_checker.match (case) for case in cases)

# Hand-written case conversions, read from the 'caseconvs' file when this
# module is imported. get_case_conv_table also writes 'caseconvs-useful'
# and exits the process on a parse error. get_case_conv consults this
# table for patterns that the automatic converters do not handle.
case_conv_table = get_case_conv_table ()
# NOTE(review): not referenced in this part of the file; check whether
# anything still uses it before removing.
cases_added = {}

def get_case_conv (cases):
	"""Returns the case conversion for the patterns in cases, or None.

	All-constructor and extended patterns are converted automatically.
	Any other patterns are looked up in the hand-written case_conv_table.
	Automatic conversion takes priority, so a hand-written entry is
	ignored for patterns that can be converted automatically."""
	if all_constructor_patterns(cases):
		conv = all_constructor_conv(cases)
	elif is_extended_pattern(cases):
		conv = extended_pattern_conv(cases)
	else:
		conv = case_conv_table.get(cases)
	return conv

# Haskell constructor names and their replacements in the generated
# Isabelle patterns. The last three map to Isabelle comments, presumably
# so that these wrapper constructors vanish from the translated pattern.
constructor_conv_table = dict(
	Just='Some',
	Nothing='None',
	Left='Inl',
	Right='Inr',
	PPtr='(* PPtr *)',
	Register='(* Register *)',
	Word='(* Word *)',
)

# Next free variable id, tracked separately for each input file.
unique_ids_per_file = {}
def get_next_unique_id ():
	"""Returns a fresh id for the file currently being processed (the
	module-level filename set by set_global). Ids count 1, 2, 3, ...
	independently for each file."""
	current = unique_ids_per_file.get(filename, 1)
	unique_ids_per_file[filename] = current + 1
	return current

def all_constructor_conv (cases):
	"""Builds the case conversion for patterns that are all simple
	constructor patterns (optionally ending in a wildcard '_').

	Constructor names found in constructor_conv_table are renamed. Each
	'_' argument becomes a fresh variable 'vN', with N taken from
	get_next_unique_id, left to right. Returns a list of (text, index)
	pairs:
	  [('case \\x of', None),
	   ('  P0 \\<Rightarrow> ', 0), ('| P1 \\<Rightarrow> ', 1), ...]"""
	conv = [('case \\x of', None)]

	for index, pattern in enumerate(cases):
		words = pattern.split()
		head = constructor_conv_table.get(words[0], words[0])
		args = []
		for word in words[1:]:
			if word == '_':
				word = 'v%d' % get_next_unique_id()
			args.append(word)
		lead = '| ' if index else '  '
		text = ' '.join([head] + args)
		conv.append(('%s%s \\<Rightarrow> ' % (lead, text), index))
	return conv

# Splits text into alternating runs of non-letters and letters. The
# capturing group makes re.split keep the letter runs, so
# extended_pattern_conv can map constructor names through
# constructor_conv_table and re-join the pieces.
word_getter = re.compile (r"([a-zA-Z]+)")

# Matches one record pattern 'Name { ... }' whose braces contain no
# further braces (the character class excludes '{' and '}'), so nested
# records are matched innermost first. The single group spans the whole
# match.
# NOTE(review): the class accepts ':' but not '#', while
# extended_pattern_conv rewrites ':' to '#' before searching, so a
# record whose field pattern uses ':' will not match -- confirm intended.
record_getter = re.compile (r"([a-zA-Z]+\s*{[a-zA-Z0-9'\s=\,_\(\):\]\[]*})")

def extended_pattern_conv(cases):
	"""Builds the case conversion for patterns accepted by
	is_extended_pattern.

	Each pattern is transformed in three steps:
	  - every ':' (Haskell list cons) is rewritten as '#';
	  - every record pattern 'Cons {field = pat, ...}' is rewritten
	    positionally by reduce_record_pattern;
	  - constructor names found in constructor_conv_table are renamed.
	Returns a list of (text, index) pairs. The first is the header
	('case \\x of', None). It is followed by one alternative per pattern,
	prefixed '  ' for the first and '| ' for the rest, ending in
	' \\<Rightarrow> '."""
	conv = [('case \\x of', None)]

	for i, pat in enumerate(cases):
		pat = '#'.join(pat.split(':'))
		# Replace one record pattern at a time. record_getter cannot match
		# across braces, so nested records are reduced innermost first.
		# Splicing by match position (rather than unpacking the result of
		# record_getter.split) also copes with several records in one
		# pattern, where split() returns more than three pieces.
		m = record_getter.search(pat)
		while m:
			record = reduce_record_pattern (m.group(1))
			pat = pat[:m.start()] + record + pat[m.end():]
			m = record_getter.search(pat)
		if '{' in pat:
			print (pat)
		assert '{' not in pat
		bits = word_getter.split(pat)
		bits = [constructor_conv_table.get(bit, bit) for bit in bits]
		pat = ''.join(bits)
		if i == 0:
			conv.append (('  %s \\<Rightarrow> ' % pat, i))
		else:
			conv.append (('| %s \\<Rightarrow> ' % pat, i))
	return conv

def reduce_record_pattern(string):
	"""Turns a Haskell record pattern 'Cons {f1 = p1, ...}' into a
	positional constructor pattern 'Cons a1 a2 ...'.

	Argument k is the pattern bound to the k-th field listed in
	all_constructor_args[Cons], a sequence of (name, type) pairs. It is
	'_' when the record pattern does not mention that field. Field
	patterns of more than one word are parenthesised.

	If Cons has no entry in all_constructor_args, a diagnostic is written
	to stderr and the process exits with status 1."""
	assert string[-1] == '}'
	[head, fields] = string[:-1].split('{')
	cons = head.strip()
	# presumably braces.str groups parenthesised text so that commas
	# inside parentheses do not split fields -- confirm in braces module
	fields = braces.str(fields, '(', ')')
	bound = {}
	for field in fields.split(','):
		field = str(field).strip()
		if not field:
			continue
		[fname, fpat] = field.split('=')
		fname = fname.strip()
		fpat = fpat.strip()
		if len (fpat.split()) > 1:
			fpat = '(%s)' % fpat
		bound[fname] = fpat
	if cons not in all_constructor_args:
		sys.stderr.write('FAIL: trying to build case for %s\n' % cons)
		sys.stderr.write('when reading %s\n' % filename)
		sys.stderr.write('but constructor not seen yet\n')
		sys.stderr.write('perhaps parse in different order?\n')
		sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n')
		sys.exit(1)
	positional = [bound.get(arg, '_')
			for (arg, _) in all_constructor_args[cons]]
	return cons + ' ' + ' '.join(positional)

def subs_nums_and_x (conv, x):
	r"""Instantiates the placeholders of a case conversion template.

	conv is a list of (line, num) pairs. Two substitutions are made in
	every line:
	  - each literal \x is replaced by x;
	  - each \vN\ (or a \vN at the very end of the line) is replaced by
	    'vK', where K is a fresh id from get_next_unique_id.
	The same N yields the same K across all lines of conv. Fresh ids are
	allocated for every index up to the largest N seen so far, in index
	order. Returns a new list of (line, num) pairs."""
	fresh_ids = []

	def fresh_var(index):
		# allocate ids for every index up to and including this one
		while index >= len(fresh_ids):
			fresh_ids.append (get_next_unique_id())
		return 'v%d' % fresh_ids[index]

	instantiated = []
	for (line, num) in conv:
		segments = x.join (line.split ('\\x')).split ('\\v')
		text = segments[0]
		for segment in segments[1:]:
			parts = segment.split ('\\', 1)
			text = text + fresh_var (int(parts[0]))
			if len(parts) > 1:
				text = text + parts[1]
		instantiated.append ((text, num))
	return instantiated

def get_supplied_transform_table ():
	"""Loads the hand-written conversions from the file 'supplied'.

	Non-blank lines, paired with their 1-based line numbers, are arranged
	into a tree by offside_tree. Every top-level entry must contain the
	token 'conv:' and have exactly two children. Other entries are
	reported on stderr and dropped. For a valid entry the children are
	passed through discard_n. The first child, converted by
	convert_to_stripped_tuple, becomes the key; the second child is
	stored as the value.

	Returns the resulting dict."""
	with open ('supplied') as f:
		raw = [line.rstrip() for line in f]

	lines = [(line, n + 1) for (n, line) in enumerate (raw)]
	lines = [(line, n) for (line, n) in lines if line != '']

	# Test the text of each line, not the (line, n) tuple: a tuple only
	# "contains" '\t' if one of its elements equals it, so a check on the
	# tuple can never fire.
	for (line, n) in lines:
		if '\t' in line:
			sys.stderr.write ('WARN: tab character in supplied line %d\n' % n)

	tree = offside_tree (lines)

	result = {}

	for line, n, children in tree:
		if ('conv:' not in line) or len(children) != 2:
			sys.stderr.write ('WARN: supplied line %d dropped\n' % n)
			if 'conv:' not in line:
				sys.stderr.write ('\t\t(token "conv:" missing)\n')
			if len(children) != 2:
				sys.stderr.write ('\t\t(%d children != 2)\n' % len(children))
			continue

		children = discard_n (children)

		[before, after] = children

		before = convert_to_stripped_tuple (before[0], before[1])

		result[before] = after

	return result

def print_tree (tree, indent = 0):
	"""Prints a (line, children) tree to stdout for debugging.

	Each node is printed as one stripped line, indented by one tab per
	nesting level."""
	for line, children in tree:
		# Build the full string before printing. The old form
		# "print (...) + s" only works as a Python 2 print statement;
		# under Python 3 it raises TypeError (None + str).
		print ('\t' * indent + line.strip())
		print_tree (children, indent + 1)

# Hand-written conversions, loaded from the 'supplied' file when this
# module is imported.
supplied_transform_table = get_supplied_transform_table ()
# Usage counter for every supplied conversion, all starting at zero.
# warn_supplied_usage reports entries whose count is still zero (the
# counts are presumably incremented elsewhere when a conversion is used).
# dict.fromkeys replaces the Python-2-only iterkeys() comprehension.
supplied_transforms_usage = dict.fromkeys (supplied_transform_table, 0)

def warn_supplied_usage ():
	"""Writes a warning to stderr for every supplied conversion whose
	counter in supplied_transforms_usage is still zero. The first
	component of the conversion's key is printed."""
	# items() works on both Python 2 and 3; iteritems() is Python 2 only.
	for (key, usage) in supplied_transforms_usage.items():
		if not usage:
			sys.stderr.write ('WARN: supplied conv unused: %s\n' \
					% key[0])

# A double-quoted string with at least one character between the quotes.
# detect_recursion deletes such strings before looking for the name.
quotes_getter = re.compile('"[^"]+"')

def detect_recursion (body):
	"""Detects whether any of the bodies of the definitions of this
	function recursively refer to it.

	Each element of body is collapsed to one line by
	reduce_to_single_line, and its double-quoted strings are removed.
	The first word of every line must be the same function name; this
	is asserted. Returns True if that name occurs anywhere in the rest of
	any line.

	NOTE(review): this is a plain substring test, so a longer identifier
	that contains the name (e.g. 'foo' inside 'fooBar') also counts as
	recursion -- confirm whether that is relied on before tightening."""
	flattened = [reduce_to_single_line(elt) for elt in body]
	unquoted = [''.join(quotes_getter.split(text)) for text in flattened]
	heads = [text.split(None, 1) for text in unquoted]
	fname = heads[0][0]
	assert not [head for (head, _) in heads if head != fname]
	return any([fname in rest for (_, rest) in heads])

def primrec_transform (d):
	"""Rewrites the definition dict d in place into primrec form and
	returns it.

	Each equation in d['body'] is first passed through body_transform
	with nopattern=True. Its '\\<equiv>' is then replaced by '= (', and
	its closing '"' by ')"'. The equations are emitted as alternatives,
	with '  ' before the first and '| ' before the rest. Also sets
	d['primrec'] = True."""
	sig = d.get('sig')
	defn = d['defined']
	equations = []
	for index, (line, children) in enumerate(d['body']):
		[(line, children)] = body_transform([(line, children)], defn, sig,
				nopattern=True)
		lead = '  ' if index == 0 else '| '
		halves = (lead + line).split('\\<equiv>')
		# the equation's own line must contain exactly one \<equiv>
		assert len(halves) == 2
		line = '= ('.join(halves)
		(line, children) = remove_trailing_string ('"', (line, children))
		(line, children) = add_trailing_string (')"', (line, children))
		equations.append((line, children))
	d['primrec'] = True
	d['body'] = equations
	return d

# A lower-case ASCII letter followed by word characters (\w).
variable_name_regex = re.compile(r"^[a-z]\w*$")

def is_variable_name (string):
	"""Returns a (truthy) match object when string matches
	variable_name_regex, otherwise None."""
	found = variable_name_regex.match(string)
	return found

def pattern_match_transform (body):
	"""Converts a body containing possibly multiple definitions
	and containing pattern matches into a normal Isabelle definition
	followed by a big Haskell case expression which is resolved
	elsewhere.

	body is a list of (line, children) tree nodes, one per defining
	equation 'f p1 ... pn = rhs'. Returns [(line, children)]:
	  - line is 'f a0 ... = case ... of';
	  - children holds one alternative '  pats -> rhs' per equation,
	    carrying that equation's remaining child nodes.

	Argument position i stays a named parameter when:
	  - the first equation's term there is not parenthesised, is not an
	    as-pattern (contains '@') and does not start upper-case; and
	  - every other equation has the same term or '_' there. A '_' in the
	    first equation may be replaced by another equation's variable
	    name.
	A position that is still '_' becomes 'x<i>'. Every other position
	becomes the parameter 'x<i>' and is scrutinised by the case
	expression (as a tuple when there are several).

	NOTE(review): a literal such as '[]' or '0' in the first equation
	passes the tests above, so it is kept as a "parameter" when all
	equations agree -- confirm this cannot happen in practice.
	"""
	splits = []
	for (line, children) in body:
		# presumably braces.str groups parenthesised text so the '='
		# tests below ignore '=' inside parentheses -- confirm in braces
		string = braces.str (line, '(', ')')
		# Absorb continuation lines from the children until the
		# equation's '=' has been seen.
		while len(string.split('=')) == 1:
			if len(children) == 1:
				# sole child: append its text; its children become ours
				[(moreline, children)] = children
				string = string + ' ' + moreline.strip()
			elif children and leading_bar.match (children[0][0]):
				# guarded equation (leading_bar is defined elsewhere in
				# this file; presumably matches a line starting with '|'):
				# add the '=' and let guarded_body_transform rewrite the
				# guard lines
				string = string + ' ='
				children = \
					guarded_body_transform (children, ' = ')
			elif children and children[0][1] == []:
				# first child is a leaf: append it and keep looking
				(moreline, _) = children.pop(0)
				string = string + ' ' + moreline.strip()
			else:
				# no '=' can be found: dump what we have, then fail
				print
				print line
				print
				for child in children:
					print child
				assert 0
				
		[lead, tail] = string.split('=', 1)
		bits = lead.split()
		# NOTE(review): alias, not a copy -- bits[1:] and unbraced[1:]
		# below hold the same terms
		unbraced = bits
		# rebound for every equation; the last equation's name is used
		function = str(bits[0])
		splits.append ((bits[1:], unbraced[1:], tail, children))
	
	# Start from the first equation's argument terms. Parenthesised
	# terms, as-patterns and terms starting upper-case must be matched by
	# the case expression, so they are marked None (= scrutinised).
	common = splits[0][0][:]
	for i, term in enumerate (common):
		if term.startswith('('):
			common[i] = None
		if '@' in term:
			common[i] = None
		if term[0].isupper():
			common[i] = None
	
	# Merge in the other equations' argument terms.
	for (bits, _, _, _) in splits[1:]:
		for i, term in enumerate (bits):
			# NOTE(review): an equation with more arguments than the first
			# only gets the tree printed; common[i] below then raises
			# IndexError. Fewer arguments are not detected.
			if i >= len(common):
				print_tree(body)
			if term != common[i]:
				is_var = is_variable_name(str(term))
				if common[i] == '_' and is_var:
					common[i] = term
				elif term != '_':
					common[i] = None
	
	# Positions that are '_' everywhere still need a parameter name.
	for i, term in enumerate (common):
		if term == '_':
			common[i] = 'x%d' % i
	
	# Indices of the scrutinised positions.
	blanks = [i for (i, n) in enumerate (common) if n == None]

	line = '%s ' % function	
	for i, name in enumerate (common):
		if name == None:
			line = line + 'x%d ' % i
		else:
			line = line + '%s ' % name
	# NOTE(review): with no scrutinised position this debug output is
	# printed and the else branch below then fails on blanks[0]
	# (IndexError).
	if blanks == []:
		print splits
		print common
	if len(blanks) == 1:
		line = line + '= case x%d of' % blanks[0]
	else:
		line = line + '= case (x%d' % blanks[0]
		for i in blanks[1:]:
			line = line + ', x%d' % i
		line = line + ') of'

	# One alternative per equation: its terms at the scrutinised
	# positions (a tuple if several), then '->' and the equation's rhs.
	children = []
	for (bits, unbraced, tail, c) in splits:
		if len(blanks) == 1:
			l = '  %s' % unbraced[blanks[0]]
		else:
			l = '  (%s' % unbraced[blanks[0]]
			for i in blanks[1:]:
				l = l + ', %s' % unbraced[i]
			l = l + ')'
		l = l + ' -> %s' % tail
		children.append ((l, c))
	
	return [(line, children)]

def get_lambda_body_lines (d):
	"""Returns lines equivalent to the body of the function as
	a lambda expression.

	d['body'] must hold exactly one (line, children) equation
	'f a1 ... an \\<equiv> rhs'. The first character of its line is
	dropped; presumably this is the opening quote. The equation becomes
	'(\\<lambda>a1 ... an. rhs', followed by the flattened child lines.
	The final '"' of the last line is replaced by ')'."""
	fn = d['defined']

	[(head, children)] = d['body']
	head = head[1:]
	equiv = '\\<equiv>'
	# The \<equiv> may sit on the first child line rather than the head.
	if equiv not in head and equiv in children[0][0]:
		(extra, extra_children) = children[0]
		children = extra_children + children[1:]
		head = head + extra
	[lhs, rhs] = head.split(equiv, 1)
	words = lhs.split()
	assert fn in words[0]

	opening = '(\\<lambda>' + ' '.join(words[1:]) + '. ' + rhs
	lines = [opening] + flatten_tree (children)
	assert lines[-1].endswith('"')
	lines[-1] = lines[-1][:-1] + ')'
	return lines

def add_trailing_string (s, node):
	"""Returns a copy of the (line, children) tree node with s appended
	to its last line, i.e. the line of the deepest last child.

	node is taken as one argument and unpacked here, because tuple
	parameters in the signature are a syntax error in Python 3
	(PEP 3113). Positional callers are unaffected."""
	(line, children) = node
	if not children:
		return (line + s, children)
	modified = add_trailing_string (s, children[-1])
	return (line, children[0:-1] + [modified])

def remove_trailing_string (s, node, _handled = False):
	"""Returns a copy of the (line, children) tree node with the suffix s
	removed from its last line, i.e. the line of the deepest last child.

	If that line does not end with s, the problem is described on stderr
	and AssertionError is raised. On any failure, the outermost call also
	writes the node it was handling to stderr before re-raising.
	_handled is internal and marks recursive calls.

	node is unpacked here rather than in the signature because tuple
	parameters are a syntax error in Python 3 (PEP 3113)."""
	(line, children) = node
	if not _handled:
		try:
			return remove_trailing_string (s, node, _handled = True)
		except Exception:
			sys.stderr.write ('handling %s\n' % ((line, children),))
			raise
	if not children:
		if not line.endswith (s):
			sys.stderr.write ('ERR: expected %r\n' % line)
			sys.stderr.write ('to end with %r\n' % s)
			# explicit raise so the check survives python -O
			raise AssertionError ('%r does not end with %r' % (line, s))
		# Slice by explicit length: line[:-0] would be '' when s == ''.
		return (line[:len(line) - len(s)], [])
	modified = remove_trailing_string (s, children[-1], _handled=True)
	return (line, children[0:-1] + [modified])

def get_trailing_string (n, node):
	"""Returns the last n characters of the last line in the (line,
	children) tree node, i.e. the line of the deepest last child. The
	whole line is returned if it is shorter than n.

	node is unpacked here rather than in the signature because tuple
	parameters are a syntax error in Python 3 (PEP 3113)."""
	(line, children) = node
	if not children:
		# Not line[-n:], which returns the whole line when n == 0.
		return line[max(len(line) - n, 0):]
	return get_trailing_string (n, children[-1])

def has_trailing_string (s, node):
	"""Returns True if the last line of the (line, children) tree node,
	i.e. the line of the deepest last child, ends with s.

	node is unpacked here rather than in the signature because tuple
	parameters are a syntax error in Python 3 (PEP 3113)."""
	(line, children) = node
	while children:
		(line, children) = children[-1]
	return line.endswith(s)

def ensure_type_ordering (defs):
	"""Returns defs reordered so that each 'newtype' definition follows
	the newtypes it depends on. All other definitions come afterwards,
	in their original relative order.

	A newtype depends on another when the other's 'typename' is in its
	'typedeps'. Starting from the first remaining newtype, its dependency
	chain is followed until a newtype with no remaining dependency is
	found; that one is emitted next. A newtype listing itself is not
	treated as a dependency. A cycle between distinct newtypes raises
	ValueError. On any error the remaining type names are printed before
	the exception propagates."""
	typedefs = [d for d in defs if d['type'] == 'newtype']
	other = [d for d in defs if d['type'] != 'newtype']

	final_typedefs = []
	while typedefs:
		try:
			i = 0
			# indices visited along the current chain, to detect cycles
			chain = [i]
			while True:
				deps = typedefs[i]['typedeps']
				for j, term in enumerate (typedefs):
					if j != i and term['typename'] in deps:
						break
				else:
					break
				if j in chain:
					names = [str(typedefs[k]['typename']) for k in chain + [j]]
					raise ValueError ('cyclic newtype dependency: %s'
							% ' -> '.join(names))
				chain.append (j)
				i = j
			final_typedefs.append (typedefs.pop(i))
		except Exception:
			print ('Exception hit ordering types:')
			for td in typedefs:
				print ('  - %s' % td['typename'])
			# bare raise keeps the original traceback
			raise

	return final_typedefs + other

def lead_ws (string):
	"""Returns the leading whitespace of string: everything before the
	first non-whitespace character, or all of string if it is blank."""
	rest = string.lstrip()
	return string[:len(string) - len(rest)]

def adjust_ws (node, n):
	"""Returns a copy of the (line, children) tree node with every line
	shifted by n columns.

	For n > 0, n spaces are prepended to every line. For n <= 0, the
	first -n characters are dropped without checking that they are
	whitespace.

	node is unpacked here rather than in the signature because tuple
	parameters are a syntax error in Python 3 (PEP 3113)."""
	(line, children) = node
	if n > 0:
		line = ' ' * n + line
	else:
		line = line[-n:]

	return (line, [adjust_ws(child, n) for child in children])

# One or more 'word.' components: the module qualifier of a name such as
# 'Data.Map.lookup' (the match is 'Data.Map.').
modulename = re.compile (r"(\w+\.)+")

def perform_module_redirects (lines, call):
	"""Applies subst_module_redirects to every line and returns the
	results as a new list."""
	redirected = []
	for line in lines:
		redirected.append (subst_module_redirects(line, call))
	return redirected

def subst_module_redirects (line, call):
	"""Rewrites every module qualifier in line via call.moduletranslations.

	Each maximal 'A.B.' qualifier whose module name 'A.B' is a key is
	replaced by the mapped name followed by '.'. Other qualifiers are
	left unchanged."""
	def redirect(match):
		module = match.group(0)[:-1]
		# looked up per match, so lines without qualifiers never touch it
		translations = call.moduletranslations
		if module in translations:
			module = translations[module]
		return module + '.'

	return modulename.sub(redirect, line)

